Spaces:
Runtime error
Runtime error
pseudotensor
committed on
Commit
·
2ce9a1a
1
Parent(s):
b368114
Update with h2oGPT hash ad9d685b188cece0b9c69716ea8e320b74f0caf7
Browse files
- client_test.py +6 -3
- enums.py +16 -6
- evaluate_params.py +4 -0
- gen.py +135 -43
- gpt4all_llm.py +18 -8
- gpt_langchain.py +174 -112
- gradio_runner.py +456 -178
- gradio_utils/__init__.py +0 -0
- gradio_utils/__pycache__/__init__.cpython-310.pyc +0 -0
- gradio_utils/__pycache__/css.cpython-310.pyc +0 -0
- gradio_utils/css.py +4 -0
- h2oai_pipeline.py +4 -1
- iterators/__pycache__/timeout_iterator.cpython-310.pyc +0 -0
- iterators/timeout_iterator.py +1 -1
- prompter.py +28 -0
- requirements.txt +16 -16
- utils.py +83 -6
client_test.py
CHANGED
@@ -48,7 +48,7 @@ import markdown # pip install markdown
|
|
48 |
import pytest
|
49 |
from bs4 import BeautifulSoup # pip install beautifulsoup4
|
50 |
|
51 |
-
from enums import
|
52 |
|
53 |
debug = False
|
54 |
|
@@ -68,6 +68,7 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False,
|
|
68 |
max_new_tokens=50,
|
69 |
top_k_docs=3,
|
70 |
langchain_mode='Disabled',
|
|
|
71 |
langchain_action=LangChainAction.QUERY.value,
|
72 |
langchain_agents=[],
|
73 |
prompt_dict=None):
|
@@ -95,12 +96,13 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False,
|
|
95 |
instruction_nochat=prompt if not chat else '',
|
96 |
iinput_nochat='', # only for chat=False
|
97 |
langchain_mode=langchain_mode,
|
|
|
98 |
langchain_action=langchain_action,
|
99 |
langchain_agents=langchain_agents,
|
100 |
top_k_docs=top_k_docs,
|
101 |
chunk=True,
|
102 |
chunk_size=512,
|
103 |
-
document_subset=
|
104 |
document_choice=[],
|
105 |
)
|
106 |
from evaluate_params import eval_func_param_names
|
@@ -204,10 +206,11 @@ def run_client_nochat_api_lean_morestuff(prompt, prompt_type='human_bot', max_ne
|
|
204 |
instruction_nochat=prompt,
|
205 |
iinput_nochat='',
|
206 |
langchain_mode='Disabled',
|
|
|
207 |
langchain_action=LangChainAction.QUERY.value,
|
208 |
langchain_agents=[],
|
209 |
top_k_docs=4,
|
210 |
-
document_subset=
|
211 |
document_choice=[],
|
212 |
)
|
213 |
|
|
|
48 |
import pytest
|
49 |
from bs4 import BeautifulSoup # pip install beautifulsoup4
|
50 |
|
51 |
+
from enums import DocumentSubset, LangChainAction
|
52 |
|
53 |
debug = False
|
54 |
|
|
|
68 |
max_new_tokens=50,
|
69 |
top_k_docs=3,
|
70 |
langchain_mode='Disabled',
|
71 |
+
add_chat_history_to_context=True,
|
72 |
langchain_action=LangChainAction.QUERY.value,
|
73 |
langchain_agents=[],
|
74 |
prompt_dict=None):
|
|
|
96 |
instruction_nochat=prompt if not chat else '',
|
97 |
iinput_nochat='', # only for chat=False
|
98 |
langchain_mode=langchain_mode,
|
99 |
+
add_chat_history_to_context=add_chat_history_to_context,
|
100 |
langchain_action=langchain_action,
|
101 |
langchain_agents=langchain_agents,
|
102 |
top_k_docs=top_k_docs,
|
103 |
chunk=True,
|
104 |
chunk_size=512,
|
105 |
+
document_subset=DocumentSubset.Relevant.name,
|
106 |
document_choice=[],
|
107 |
)
|
108 |
from evaluate_params import eval_func_param_names
|
|
|
206 |
instruction_nochat=prompt,
|
207 |
iinput_nochat='',
|
208 |
langchain_mode='Disabled',
|
209 |
+
add_chat_history_to_context=True,
|
210 |
langchain_action=LangChainAction.QUERY.value,
|
211 |
langchain_agents=[],
|
212 |
top_k_docs=4,
|
213 |
+
document_subset=DocumentSubset.Relevant.name,
|
214 |
document_choice=[],
|
215 |
)
|
216 |
|
enums.py
CHANGED
@@ -32,25 +32,29 @@ class PromptType(Enum):
|
|
32 |
mptchat = 26
|
33 |
falcon = 27
|
34 |
guanaco = 28
|
|
|
35 |
|
36 |
|
37 |
-
class
|
38 |
Relevant = 0
|
39 |
-
|
40 |
-
|
41 |
|
42 |
|
43 |
non_query_commands = [
|
44 |
-
|
45 |
-
|
46 |
]
|
47 |
|
48 |
|
|
|
|
|
|
|
|
|
49 |
class LangChainMode(Enum):
|
50 |
"""LangChain mode"""
|
51 |
|
52 |
DISABLED = "Disabled"
|
53 |
-
CHAT_LLM = "ChatLLM"
|
54 |
LLM = "LLM"
|
55 |
ALL = "All"
|
56 |
WIKI = "wiki"
|
@@ -61,6 +65,12 @@ class LangChainMode(Enum):
|
|
61 |
H2O_DAI_DOCS = "DriverlessAI docs"
|
62 |
|
63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
class LangChainAction(Enum):
|
65 |
"""LangChain action"""
|
66 |
|
|
|
32 |
mptchat = 26
|
33 |
falcon = 27
|
34 |
guanaco = 28
|
35 |
+
llama2 = 29
|
36 |
|
37 |
|
38 |
+
class DocumentSubset(Enum):
|
39 |
Relevant = 0
|
40 |
+
RelSources = 1
|
41 |
+
TopKSources = 2
|
42 |
|
43 |
|
44 |
non_query_commands = [
|
45 |
+
DocumentSubset.RelSources.name,
|
46 |
+
DocumentSubset.TopKSources.name
|
47 |
]
|
48 |
|
49 |
|
50 |
+
class DocumentChoice(Enum):
|
51 |
+
ALL = 'All'
|
52 |
+
|
53 |
+
|
54 |
class LangChainMode(Enum):
|
55 |
"""LangChain mode"""
|
56 |
|
57 |
DISABLED = "Disabled"
|
|
|
58 |
LLM = "LLM"
|
59 |
ALL = "All"
|
60 |
WIKI = "wiki"
|
|
|
65 |
H2O_DAI_DOCS = "DriverlessAI docs"
|
66 |
|
67 |
|
68 |
+
# modes should not be removed from visible list or added by name
|
69 |
+
langchain_modes_intrinsic = [LangChainMode.DISABLED.value,
|
70 |
+
LangChainMode.LLM.value,
|
71 |
+
LangChainMode.MY_DATA.value]
|
72 |
+
|
73 |
+
|
74 |
class LangChainAction(Enum):
|
75 |
"""LangChain action"""
|
76 |
|
evaluate_params.py
CHANGED
@@ -1,3 +1,6 @@
|
|
|
|
|
|
|
|
1 |
no_default_param_names = [
|
2 |
'instruction',
|
3 |
'iinput',
|
@@ -30,6 +33,7 @@ eval_func_param_names = ['instruction',
|
|
30 |
'instruction_nochat',
|
31 |
'iinput_nochat',
|
32 |
'langchain_mode',
|
|
|
33 |
'langchain_action',
|
34 |
'langchain_agents',
|
35 |
'top_k_docs',
|
|
|
1 |
+
input_args_list = ['model_state', 'my_db_state', 'selection_docs_state']
|
2 |
+
|
3 |
+
|
4 |
no_default_param_names = [
|
5 |
'instruction',
|
6 |
'iinput',
|
|
|
33 |
'instruction_nochat',
|
34 |
'iinput_nochat',
|
35 |
'langchain_mode',
|
36 |
+
'add_chat_history_to_context',
|
37 |
'langchain_action',
|
38 |
'langchain_agents',
|
39 |
'top_k_docs',
|
gen.py
CHANGED
@@ -8,7 +8,6 @@ import sys
|
|
8 |
import os
|
9 |
import time
|
10 |
import traceback
|
11 |
-
import types
|
12 |
import typing
|
13 |
import warnings
|
14 |
from datetime import datetime
|
@@ -28,12 +27,12 @@ os.environ['BITSANDBYTES_NOWELCOME'] = '1'
|
|
28 |
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
|
29 |
|
30 |
from evaluate_params import eval_func_param_names, no_default_param_names
|
31 |
-
from enums import
|
32 |
-
source_postfix, LangChainAction, LangChainAgent
|
33 |
from loaders import get_loaders
|
34 |
from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial, EThread, get_githash, \
|
35 |
import_matplotlib, get_device, makedirs, get_kwargs, start_faulthandler, get_hf_server, FakeTokenizer, remove, \
|
36 |
-
have_langchain, set_openai
|
37 |
|
38 |
start_faulthandler()
|
39 |
import_matplotlib()
|
@@ -50,8 +49,6 @@ from transformers import GenerationConfig, AutoModel, TextIteratorStreamer
|
|
50 |
from prompter import Prompter, inv_prompt_type_to_model_lower, non_hf_types, PromptType, get_prompt, generate_prompt
|
51 |
from stopping import get_stopping
|
52 |
|
53 |
-
langchain_modes = [x.value for x in list(LangChainMode)]
|
54 |
-
|
55 |
langchain_actions = [x.value for x in list(LangChainAction)]
|
56 |
|
57 |
langchain_agents_list = [x.value for x in list(LangChainAgent)]
|
@@ -116,6 +113,7 @@ def main(
|
|
116 |
show_examples: bool = None,
|
117 |
verbose: bool = False,
|
118 |
h2ocolors: bool = True,
|
|
|
119 |
height: int = 600,
|
120 |
show_lora: bool = True,
|
121 |
login_mode_if_model0: bool = False,
|
@@ -147,14 +145,16 @@ def main(
|
|
147 |
langchain_action: str = LangChainAction.QUERY.value,
|
148 |
langchain_agents: list = [],
|
149 |
force_langchain_evaluate: bool = False,
|
|
|
150 |
visible_langchain_modes: list = ['UserData', 'MyData'],
|
151 |
# WIP:
|
152 |
# visible_langchain_actions: list = langchain_actions.copy(),
|
153 |
visible_langchain_actions: list = [LangChainAction.QUERY.value, LangChainAction.SUMMARIZE_MAP.value],
|
154 |
visible_langchain_agents: list = langchain_agents_list.copy(),
|
155 |
-
document_subset: str =
|
156 |
-
document_choice: list = [],
|
157 |
user_path: str = None,
|
|
|
158 |
detect_user_path_changes_every_query: bool = False,
|
159 |
use_llm_if_no_docs: bool = False,
|
160 |
load_db_if_exists: bool = True,
|
@@ -163,7 +163,10 @@ def main(
|
|
163 |
use_openai_embedding: bool = False,
|
164 |
use_openai_model: bool = False,
|
165 |
hf_embedding_model: str = None,
|
|
|
|
|
166 |
allow_upload_to_user_data: bool = True,
|
|
|
167 |
allow_upload_to_my_data: bool = True,
|
168 |
enable_url_upload: bool = True,
|
169 |
enable_text_upload: bool = True,
|
@@ -180,6 +183,7 @@ def main(
|
|
180 |
pre_load_caption_model: bool = False,
|
181 |
caption_gpu: bool = True,
|
182 |
enable_ocr: bool = False,
|
|
|
183 |
):
|
184 |
"""
|
185 |
|
@@ -259,6 +263,7 @@ def main(
|
|
259 |
:param show_examples: whether to show clickable examples in gradio
|
260 |
:param verbose: whether to show verbose prints
|
261 |
:param h2ocolors: whether to use H2O.ai theme
|
|
|
262 |
:param height: height of chat window
|
263 |
:param show_lora: whether to show LORA options in UI (expert so can be hard to understand)
|
264 |
:param login_mode_if_model0: set to True to load --base_model after client logs in, to be able to free GPU memory when model is swapped
|
@@ -287,7 +292,7 @@ def main(
|
|
287 |
:param eval_prompts_only_seed: for no gradio benchmark, seed for eval_filename sampling
|
288 |
:param eval_as_output: for no gradio benchmark, whether to test eval_filename output itself
|
289 |
:param langchain_mode: Data source to include. Choose "UserData" to only consume files from make_db.py.
|
290 |
-
None: auto mode, check if langchain package exists, at least do
|
291 |
WARNING: wiki_full requires extra data processing via read_wiki_full.py and requires really good workstation to generate db, unless already present.
|
292 |
:param langchain_action: Mode langchain operations in on documents.
|
293 |
Query: Make query of document(s)
|
@@ -299,18 +304,28 @@ def main(
|
|
299 |
:param force_langchain_evaluate: Whether to force langchain LLM use even if not doing langchain, mostly for testing.
|
300 |
:param user_path: user path to glob from to generate db for vector search, for 'UserData' langchain mode.
|
301 |
If already have db, any new/changed files are added automatically if path set, does not have to be same path used for prior db sources
|
|
|
|
|
|
|
|
|
302 |
:param detect_user_path_changes_every_query: whether to detect if any files changed or added every similarity search (by file hashes).
|
303 |
Expensive for large number of files, so not done by default. By default only detect changes during db loading.
|
|
|
304 |
:param visible_langchain_modes: dbs to generate at launch to be ready for LLM
|
305 |
Can be up to ['wiki', 'wiki_full', 'UserData', 'MyData', 'github h2oGPT', 'DriverlessAI docs']
|
306 |
But wiki_full is expensive and requires preparation
|
307 |
To allow scratch space only live in session, add 'MyData' to list
|
308 |
Default: If only want to consume local files, e.g. prepared by make_db.py, only include ['UserData']
|
|
|
|
|
|
|
|
|
|
|
309 |
:param visible_langchain_actions: Which actions to allow
|
310 |
:param visible_langchain_agents: Which agents to allow
|
311 |
:param document_subset: Default document choice when taking subset of collection
|
312 |
-
:param document_choice: Chosen document(s) by internal name
|
313 |
-
:param use_llm_if_no_docs: Whether to use LLM even if no documents, when langchain_mode=UserData or MyData
|
314 |
:param load_db_if_exists: Whether to load chroma db if exists or re-generate db
|
315 |
:param keep_sources_in_context: Whether to keep url sources in context, not helpful usually
|
316 |
:param db_type: 'faiss' for in-memory or 'chroma' or 'weaviate' for persisted on disk
|
@@ -321,13 +336,20 @@ def main(
|
|
321 |
Can also choose simpler model with 384 parameters per embedding: "sentence-transformers/all-MiniLM-L6-v2"
|
322 |
Can also choose even better embedding with 1024 parameters: 'hkunlp/instructor-xl'
|
323 |
We support automatically changing of embeddings for chroma, with a backup of db made if this is done
|
324 |
-
:param
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
325 |
:param allow_upload_to_my_data: Whether to allow file uploads to update scratch vector db
|
326 |
:param enable_url_upload: Whether to allow upload from URL
|
327 |
:param enable_text_upload: Whether to allow upload of text
|
328 |
:param enable_sources_list: Whether to allow list (or download for non-shared db) of list of sources for chosen db
|
329 |
:param chunk: Whether to chunk data (True unless know data is already optimally chunked)
|
330 |
-
:param chunk_size: Size of chunks, with typically top-4 passed to LLM, so
|
331 |
:param top_k_docs: number of chunks to give LLM
|
332 |
:param reverse_docs: whether to reverse docs order so most relevant is closest to question.
|
333 |
Best choice for sufficiently smart model, and truncation occurs for oldest context, so best then too.
|
@@ -347,6 +369,9 @@ def main(
|
|
347 |
Recommended if using larger caption model
|
348 |
:param caption_gpu: If support caption, then use GPU if exists
|
349 |
:param enable_ocr: Whether to support OCR on images
|
|
|
|
|
|
|
350 |
:return:
|
351 |
"""
|
352 |
if base_model is None:
|
@@ -408,6 +433,26 @@ def main(
|
|
408 |
if langchain_mode is not None:
|
409 |
visible_langchain_modes += [langchain_mode]
|
410 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
411 |
assert langchain_action in langchain_actions, "Invalid langchain_action %s" % langchain_action
|
412 |
assert len(
|
413 |
set(langchain_agents).difference(langchain_agents_list)) == 0, "Invalid langchain_agents %s" % langchain_agents
|
@@ -421,22 +466,22 @@ def main(
|
|
421 |
# auto-set langchain_mode
|
422 |
if have_langchain and langchain_mode is None:
|
423 |
# start in chat mode, in case just want to chat and don't want to get "No documents to query" by default.
|
424 |
-
langchain_mode = LangChainMode.
|
425 |
-
if allow_upload_to_user_data and not is_public and
|
426 |
print("Auto set langchain_mode=%s. Could use UserData instead." % langchain_mode, flush=True)
|
427 |
elif allow_upload_to_my_data:
|
428 |
print("Auto set langchain_mode=%s. Could use MyData instead."
|
429 |
" To allow UserData to pull files from disk,"
|
430 |
-
" set user_path and ensure allow_upload_to_user_data=True" % langchain_mode,
|
|
|
431 |
else:
|
432 |
raise RuntimeError("Please pass --langchain_mode=<chosen mode> out of %s" % langchain_modes)
|
433 |
-
if not have_langchain and langchain_mode not in [None, LangChainMode.DISABLED.value, LangChainMode.LLM.value
|
434 |
-
LangChainMode.CHAT_LLM.value]:
|
435 |
raise RuntimeError("Asked for LangChain mode but langchain python package cannot be found.")
|
436 |
if langchain_mode is None:
|
437 |
# if not set yet, disable
|
438 |
langchain_mode = LangChainMode.DISABLED.value
|
439 |
-
print("Auto set langchain_mode=%s" % langchain_mode, flush=True)
|
440 |
|
441 |
if is_public:
|
442 |
allow_upload_to_user_data = False
|
@@ -547,8 +592,6 @@ def main(
|
|
547 |
|
548 |
if offload_folder:
|
549 |
makedirs(offload_folder)
|
550 |
-
if user_path:
|
551 |
-
makedirs(user_path)
|
552 |
|
553 |
placeholder_instruction, placeholder_input, \
|
554 |
stream_output, show_examples, \
|
@@ -574,7 +617,7 @@ def main(
|
|
574 |
verbose,
|
575 |
)
|
576 |
|
577 |
-
git_hash = get_githash()
|
578 |
locals_dict = locals()
|
579 |
locals_print = '\n'.join(['%s: %s' % (k, v) for k, v in locals_dict.items()])
|
580 |
if verbose:
|
@@ -588,7 +631,7 @@ def main(
|
|
588 |
get_some_dbs_from_hf()
|
589 |
dbs = {}
|
590 |
for langchain_mode1 in visible_langchain_modes:
|
591 |
-
if langchain_mode1 in ['MyData']:
|
592 |
# don't use what is on disk, remove it instead
|
593 |
for gpath1 in glob.glob(os.path.join(scratch_base_dir, 'db_dir_%s*' % langchain_mode1)):
|
594 |
if os.path.isdir(gpath1):
|
@@ -603,7 +646,7 @@ def main(
|
|
603 |
db = prep_langchain(persist_directory1,
|
604 |
load_db_if_exists,
|
605 |
db_type, use_openai_embedding,
|
606 |
-
langchain_mode1,
|
607 |
hf_embedding_model,
|
608 |
kwargs_make_db=locals())
|
609 |
finally:
|
@@ -622,6 +665,14 @@ def main(
|
|
622 |
model_state_none = dict(model=None, tokenizer=None, device=None,
|
623 |
base_model=None, tokenizer_base_model=None, lora_weights=None,
|
624 |
inference_server=None, prompt_type=None, prompt_dict=None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
625 |
|
626 |
if cli:
|
627 |
from cli import run_cli
|
@@ -1280,6 +1331,7 @@ def get_score_model(score_model: str = None,
|
|
1280 |
def evaluate(
|
1281 |
model_state,
|
1282 |
my_db_state,
|
|
|
1283 |
# START NOTE: Examples must have same order of parameters
|
1284 |
instruction,
|
1285 |
iinput,
|
@@ -1302,6 +1354,7 @@ def evaluate(
|
|
1302 |
instruction_nochat,
|
1303 |
iinput_nochat,
|
1304 |
langchain_mode,
|
|
|
1305 |
langchain_action,
|
1306 |
langchain_agents,
|
1307 |
top_k_docs,
|
@@ -1317,6 +1370,9 @@ def evaluate(
|
|
1317 |
save_dir=None,
|
1318 |
sanitize_bot_response=False,
|
1319 |
model_state0=None,
|
|
|
|
|
|
|
1320 |
memory_restriction_level=None,
|
1321 |
max_max_new_tokens=None,
|
1322 |
is_public=None,
|
@@ -1327,11 +1383,11 @@ def evaluate(
|
|
1327 |
use_llm_if_no_docs=False,
|
1328 |
load_db_if_exists=True,
|
1329 |
dbs=None,
|
1330 |
-
user_path=None,
|
1331 |
detect_user_path_changes_every_query=None,
|
1332 |
use_openai_embedding=None,
|
1333 |
use_openai_model=None,
|
1334 |
hf_embedding_model=None,
|
|
|
1335 |
db_type=None,
|
1336 |
n_jobs=None,
|
1337 |
first_para=None,
|
@@ -1360,6 +1416,16 @@ def evaluate(
|
|
1360 |
assert chunk_size is not None and isinstance(chunk_size, int)
|
1361 |
assert n_jobs is not None
|
1362 |
assert first_para is not None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1363 |
|
1364 |
if debug:
|
1365 |
locals_dict = locals().copy()
|
@@ -1481,18 +1547,22 @@ def evaluate(
|
|
1481 |
assert langchain_action in langchain_actions, "Invalid langchain_action %s" % langchain_action
|
1482 |
assert len(
|
1483 |
set(langchain_agents).difference(langchain_agents_list)) == 0, "Invalid langchain_agents %s" % langchain_agents
|
1484 |
-
if
|
1485 |
-
|
1486 |
-
elif
|
1487 |
-
db1 =
|
|
|
|
|
|
|
|
|
1488 |
else:
|
1489 |
-
|
1490 |
-
do_langchain_path = langchain_mode not in [False, 'Disabled', '
|
1491 |
base_model in non_hf_types or \
|
1492 |
force_langchain_evaluate
|
1493 |
if do_langchain_path:
|
1494 |
outr = ""
|
1495 |
-
# use smaller
|
1496 |
from gpt_langchain import run_qa_db
|
1497 |
gen_hyper_langchain = dict(do_sample=do_sample,
|
1498 |
temperature=temperature,
|
@@ -1515,10 +1585,11 @@ def evaluate(
|
|
1515 |
prompter=prompter,
|
1516 |
use_llm_if_no_docs=use_llm_if_no_docs,
|
1517 |
load_db_if_exists=load_db_if_exists,
|
1518 |
-
db=
|
1519 |
-
|
1520 |
detect_user_path_changes_every_query=detect_user_path_changes_every_query,
|
1521 |
-
|
|
|
1522 |
use_openai_embedding=use_openai_embedding,
|
1523 |
use_openai_model=use_openai_model,
|
1524 |
hf_embedding_model=hf_embedding_model,
|
@@ -1676,6 +1747,7 @@ def evaluate(
|
|
1676 |
chat_client = False
|
1677 |
where_from = "gr_client"
|
1678 |
client_langchain_mode = 'Disabled'
|
|
|
1679 |
client_langchain_action = LangChainAction.QUERY.value
|
1680 |
client_langchain_agents = []
|
1681 |
gen_server_kwargs = dict(temperature=temperature,
|
@@ -1729,13 +1801,14 @@ def evaluate(
|
|
1729 |
instruction_nochat=gr_prompt if not chat_client else '',
|
1730 |
iinput_nochat=gr_iinput, # only for chat=False
|
1731 |
langchain_mode=client_langchain_mode,
|
|
|
1732 |
langchain_action=client_langchain_action,
|
1733 |
langchain_agents=client_langchain_agents,
|
1734 |
top_k_docs=top_k_docs,
|
1735 |
chunk=chunk,
|
1736 |
chunk_size=chunk_size,
|
1737 |
-
document_subset=
|
1738 |
-
document_choice=[],
|
1739 |
)
|
1740 |
api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
|
1741 |
if not stream_output:
|
@@ -2029,7 +2102,7 @@ def evaluate(
|
|
2029 |
|
2030 |
|
2031 |
inputs_list_names = list(inspect.signature(evaluate).parameters)
|
2032 |
-
state_names = ['model_state', 'my_db_state']
|
2033 |
inputs_kwargs_list = [x for x in inputs_list_names if x not in eval_func_param_names + state_names]
|
2034 |
|
2035 |
|
@@ -2312,8 +2385,8 @@ y = np.random.randint(0, 1, 100)
|
|
2312 |
|
2313 |
# move to correct position
|
2314 |
for example in examples:
|
2315 |
-
example += [chat, '', '', LangChainMode.DISABLED.value, LangChainAction.QUERY.value, [],
|
2316 |
-
top_k_docs, chunk, chunk_size,
|
2317 |
]
|
2318 |
# adjust examples if non-chat mode
|
2319 |
if not chat:
|
@@ -2373,7 +2446,7 @@ def score_qa(smodel, stokenizer, max_length_tokenize, question, answer, cutoff_l
|
|
2373 |
truncation=True,
|
2374 |
max_length=max_length_tokenize).to(smodel.device)
|
2375 |
try:
|
2376 |
-
score = torch.sigmoid(smodel(**inputs).logits[0]).cpu().detach().numpy()[0]
|
2377 |
except torch.cuda.OutOfMemoryError as e:
|
2378 |
print("GPU OOM 3: question: %s answer: %s exception: %s" % (question, answer, str(e)), flush=True)
|
2379 |
del inputs
|
@@ -2458,12 +2531,15 @@ def get_minmax_top_k_docs(is_public):
|
|
2458 |
return min_top_k_docs, max_top_k_docs, label_top_k_docs
|
2459 |
|
2460 |
|
2461 |
-
def history_to_context(history, langchain_mode1,
|
|
|
|
|
2462 |
memory_restriction_level1, keep_sources_in_context1):
|
2463 |
"""
|
2464 |
consumes all history up to (but not including) latest history item that is presumed to be an [instruction, None] pair
|
2465 |
:param history:
|
2466 |
:param langchain_mode1:
|
|
|
2467 |
:param prompt_type1:
|
2468 |
:param prompt_dict1:
|
2469 |
:param chat1:
|
@@ -2476,7 +2552,7 @@ def history_to_context(history, langchain_mode1, prompt_type1, prompt_dict1, cha
|
|
2476 |
_, _, _, max_prompt_length = get_cutoffs(memory_restriction_level1,
|
2477 |
for_context=True, model_max_length=model_max_length1)
|
2478 |
context1 = ''
|
2479 |
-
if max_prompt_length is not None and
|
2480 |
context1 = ''
|
2481 |
# - 1 below because current instruction already in history from user()
|
2482 |
for histi in range(0, len(history) - 1):
|
@@ -2512,6 +2588,22 @@ def history_to_context(history, langchain_mode1, prompt_type1, prompt_dict1, cha
|
|
2512 |
return context1
|
2513 |
|
2514 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2515 |
def entrypoint_main():
|
2516 |
"""
|
2517 |
Examples:
|
|
|
8 |
import os
|
9 |
import time
|
10 |
import traceback
|
|
|
11 |
import typing
|
12 |
import warnings
|
13 |
from datetime import datetime
|
|
|
27 |
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
|
28 |
|
29 |
from evaluate_params import eval_func_param_names, no_default_param_names
|
30 |
+
from enums import DocumentSubset, LangChainMode, no_lora_str, model_token_mapping, no_model_str, source_prefix, \
|
31 |
+
source_postfix, LangChainAction, LangChainAgent, DocumentChoice
|
32 |
from loaders import get_loaders
|
33 |
from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial, EThread, get_githash, \
|
34 |
import_matplotlib, get_device, makedirs, get_kwargs, start_faulthandler, get_hf_server, FakeTokenizer, remove, \
|
35 |
+
have_langchain, set_openai, load_collection_enum
|
36 |
|
37 |
start_faulthandler()
|
38 |
import_matplotlib()
|
|
|
49 |
from prompter import Prompter, inv_prompt_type_to_model_lower, non_hf_types, PromptType, get_prompt, generate_prompt
|
50 |
from stopping import get_stopping
|
51 |
|
|
|
|
|
52 |
langchain_actions = [x.value for x in list(LangChainAction)]
|
53 |
|
54 |
langchain_agents_list = [x.value for x in list(LangChainAgent)]
|
|
|
113 |
show_examples: bool = None,
|
114 |
verbose: bool = False,
|
115 |
h2ocolors: bool = True,
|
116 |
+
dark: bool = False, # light tends to be best
|
117 |
height: int = 600,
|
118 |
show_lora: bool = True,
|
119 |
login_mode_if_model0: bool = False,
|
|
|
145 |
langchain_action: str = LangChainAction.QUERY.value,
|
146 |
langchain_agents: list = [],
|
147 |
force_langchain_evaluate: bool = False,
|
148 |
+
langchain_modes: list = [x.value for x in list(LangChainMode)],
|
149 |
visible_langchain_modes: list = ['UserData', 'MyData'],
|
150 |
# WIP:
|
151 |
# visible_langchain_actions: list = langchain_actions.copy(),
|
152 |
visible_langchain_actions: list = [LangChainAction.QUERY.value, LangChainAction.SUMMARIZE_MAP.value],
|
153 |
visible_langchain_agents: list = langchain_agents_list.copy(),
|
154 |
+
document_subset: str = DocumentSubset.Relevant.name,
|
155 |
+
document_choice: list = [DocumentChoice.ALL.value],
|
156 |
user_path: str = None,
|
157 |
+
langchain_mode_paths: dict = {'UserData': None},
|
158 |
detect_user_path_changes_every_query: bool = False,
|
159 |
use_llm_if_no_docs: bool = False,
|
160 |
load_db_if_exists: bool = True,
|
|
|
163 |
use_openai_embedding: bool = False,
|
164 |
use_openai_model: bool = False,
|
165 |
hf_embedding_model: str = None,
|
166 |
+
cut_distance: float = 1.64,
|
167 |
+
add_chat_history_to_context: bool = True,
|
168 |
allow_upload_to_user_data: bool = True,
|
169 |
+
reload_langchain_state: bool = True,
|
170 |
allow_upload_to_my_data: bool = True,
|
171 |
enable_url_upload: bool = True,
|
172 |
enable_text_upload: bool = True,
|
|
|
183 |
pre_load_caption_model: bool = False,
|
184 |
caption_gpu: bool = True,
|
185 |
enable_ocr: bool = False,
|
186 |
+
enable_pdf_ocr: str = 'auto',
|
187 |
):
|
188 |
"""
|
189 |
|
|
|
263 |
:param show_examples: whether to show clickable examples in gradio
|
264 |
:param verbose: whether to show verbose prints
|
265 |
:param h2ocolors: whether to use H2O.ai theme
|
266 |
+
:param dark: whether to use dark mode for UI by default (still controlled in UI)
|
267 |
:param height: height of chat window
|
268 |
:param show_lora: whether to show LORA options in UI (expert so can be hard to understand)
|
269 |
:param login_mode_if_model0: set to True to load --base_model after client logs in, to be able to free GPU memory when model is swapped
|
|
|
292 |
:param eval_prompts_only_seed: for no gradio benchmark, seed for eval_filename sampling
|
293 |
:param eval_as_output: for no gradio benchmark, whether to test eval_filename output itself
|
294 |
:param langchain_mode: Data source to include. Choose "UserData" to only consume files from make_db.py.
|
295 |
+
None: auto mode, check if langchain package exists, at least do LLM if so, else Disabled
|
296 |
WARNING: wiki_full requires extra data processing via read_wiki_full.py and requires really good workstation to generate db, unless already present.
|
297 |
:param langchain_action: Mode langchain operations in on documents.
|
298 |
Query: Make query of document(s)
|
|
|
304 |
:param force_langchain_evaluate: Whether to force langchain LLM use even if not doing langchain, mostly for testing.
|
305 |
:param user_path: user path to glob from to generate db for vector search, for 'UserData' langchain mode.
|
306 |
If already have db, any new/changed files are added automatically if path set, does not have to be same path used for prior db sources
|
307 |
+
:param langchain_mode_paths: dict of langchain_mode keys and disk path values to use for source of documents
|
308 |
+
E.g. "{'UserData2': 'userpath2'}"
|
309 |
+
Can be None even if existing DB, to avoid new documents being added from that path, source links that are on disk still work.
|
310 |
+
If user_path is not None, that path is used for 'UserData' instead of the value in this dict
|
311 |
:param detect_user_path_changes_every_query: whether to detect if any files changed or added every similarity search (by file hashes).
|
312 |
Expensive for large number of files, so not done by default. By default only detect changes during db loading.
|
313 |
+
:param langchain_modes: names of collections/dbs to potentially have
|
314 |
:param visible_langchain_modes: dbs to generate at launch to be ready for LLM
|
315 |
Can be up to ['wiki', 'wiki_full', 'UserData', 'MyData', 'github h2oGPT', 'DriverlessAI docs']
|
316 |
But wiki_full is expensive and requires preparation
|
317 |
To allow scratch space only live in session, add 'MyData' to list
|
318 |
Default: If only want to consume local files, e.g. prepared by make_db.py, only include ['UserData']
|
319 |
+
If have own user modes, need to add these here or add in UI.
|
320 |
+
A state file is stored in visible_langchain_modes.pkl containing last UI-selected values of:
|
321 |
+
langchain_modes, visible_langchain_modes, and langchain_mode_paths
|
322 |
+
Delete the file if you want to start fresh,
|
323 |
+
but in any case the user_path passed in CLI is used for UserData even if was None or different
|
324 |
:param visible_langchain_actions: Which actions to allow
|
325 |
:param visible_langchain_agents: Which agents to allow
|
326 |
:param document_subset: Default document choice when taking subset of collection
|
327 |
+
:param document_choice: Chosen document(s) by internal name, 'All' means use all docs
|
328 |
+
:param use_llm_if_no_docs: Whether to use LLM even if no documents, when langchain_mode=UserData or MyData or custom
|
329 |
:param load_db_if_exists: Whether to load chroma db if exists or re-generate db
|
330 |
:param keep_sources_in_context: Whether to keep url sources in context, not helpful usually
|
331 |
:param db_type: 'faiss' for in-memory or 'chroma' or 'weaviate' for persisted on disk
|
|
|
336 |
Can also choose simpler model with 384 parameters per embedding: "sentence-transformers/all-MiniLM-L6-v2"
|
337 |
Can also choose even better embedding with 1024 parameters: 'hkunlp/instructor-xl'
|
338 |
We support automatically changing of embeddings for chroma, with a backup of db made if this is done
|
339 |
+
:param cut_distance: Distance to cut off references with larger distances when showing references.
|
340 |
+
1.64 is good to avoid dropping references for all-MiniLM-L6-v2, but instructor-large will always show excessive references.
|
341 |
+
For all-MiniLM-L6-v2, a value of 1.5 can push out even more references, or a large value of 100 can avoid any loss of references.
|
342 |
+
:param add_chat_history_to_context: Include chat context when performing action
|
343 |
+
Not supported yet for openai_chat when using document collection instead of LLM
|
344 |
+
Also not supported when using CLI mode
|
345 |
+
:param allow_upload_to_user_data: Whether to allow file uploads to update shared vector db (UserData or custom user dbs)
|
346 |
+
:param reload_langchain_state: Whether to reload visible_langchain_modes.pkl file that contains any new user collections.
|
347 |
:param allow_upload_to_my_data: Whether to allow file uploads to update scratch vector db
|
348 |
:param enable_url_upload: Whether to allow upload from URL
|
349 |
:param enable_text_upload: Whether to allow upload of text
|
350 |
:param enable_sources_list: Whether to allow list (or download for non-shared db) of list of sources for chosen db
|
351 |
:param chunk: Whether to chunk data (True unless know data is already optimally chunked)
|
352 |
+
:param chunk_size: Size of chunks, with typically top-4 passed to LLM, so needs to be in context length
|
353 |
:param top_k_docs: number of chunks to give LLM
|
354 |
:param reverse_docs: whether to reverse docs order so most relevant is closest to question.
|
355 |
Best choice for sufficiently smart model, and truncation occurs for oldest context, so best then too.
|
|
|
369 |
Recommended if using larger caption model
|
370 |
:param caption_gpu: If support caption, then use GPU if exists
|
371 |
:param enable_ocr: Whether to support OCR on images
|
372 |
+
:param enable_pdf_ocr: 'auto' means only use OCR if normal text extraction fails. Useful for pure image-based PDFs with text
|
373 |
+
'on' means always do OCR as additional parsing of same documents
|
374 |
+
'off' means don't do OCR (e.g. because it's slow even if 'auto' only would trigger if nothing else worked)
|
375 |
:return:
|
376 |
"""
|
377 |
if base_model is None:
|
|
|
433 |
if langchain_mode is not None:
|
434 |
visible_langchain_modes += [langchain_mode]
|
435 |
|
436 |
+
# update
|
437 |
+
if isinstance(langchain_mode_paths, str):
|
438 |
+
langchain_mode_paths = ast.literal_eval(langchain_mode_paths)
|
439 |
+
assert isinstance(langchain_mode_paths, dict)
|
440 |
+
if user_path:
|
441 |
+
langchain_mode_paths['UserData'] = user_path
|
442 |
+
makedirs(user_path)
|
443 |
+
|
444 |
+
if is_public:
|
445 |
+
allow_upload_to_user_data = False
|
446 |
+
if LangChainMode.USER_DATA.value in visible_langchain_modes:
|
447 |
+
visible_langchain_modes.remove(LangChainMode.USER_DATA.value)
|
448 |
+
|
449 |
+
# in-place, for non-scratch dbs
|
450 |
+
if allow_upload_to_user_data:
|
451 |
+
update_langchain(langchain_modes, visible_langchain_modes, langchain_mode_paths, '')
|
452 |
+
# always listen to CLI-passed user_path if passed
|
453 |
+
if user_path:
|
454 |
+
langchain_mode_paths['UserData'] = user_path
|
455 |
+
|
456 |
assert langchain_action in langchain_actions, "Invalid langchain_action %s" % langchain_action
|
457 |
assert len(
|
458 |
set(langchain_agents).difference(langchain_agents_list)) == 0, "Invalid langchain_agents %s" % langchain_agents
|
|
|
466 |
# auto-set langchain_mode
|
467 |
if have_langchain and langchain_mode is None:
|
468 |
# start in chat mode, in case just want to chat and don't want to get "No documents to query" by default.
|
469 |
+
langchain_mode = LangChainMode.LLM.value
|
470 |
+
if allow_upload_to_user_data and not is_public and langchain_mode_paths['UserData']:
|
471 |
print("Auto set langchain_mode=%s. Could use UserData instead." % langchain_mode, flush=True)
|
472 |
elif allow_upload_to_my_data:
|
473 |
print("Auto set langchain_mode=%s. Could use MyData instead."
|
474 |
" To allow UserData to pull files from disk,"
|
475 |
+
" set user_path or langchain_mode_paths, and ensure allow_upload_to_user_data=True" % langchain_mode,
|
476 |
+
flush=True)
|
477 |
else:
|
478 |
raise RuntimeError("Please pass --langchain_mode=<chosen mode> out of %s" % langchain_modes)
|
479 |
+
if not have_langchain and langchain_mode not in [None, LangChainMode.DISABLED.value, LangChainMode.LLM.value]:
|
|
|
480 |
raise RuntimeError("Asked for LangChain mode but langchain python package cannot be found.")
|
481 |
if langchain_mode is None:
|
482 |
# if not set yet, disable
|
483 |
langchain_mode = LangChainMode.DISABLED.value
|
484 |
+
print("Auto set langchain_mode=%s Have langchain package: %s" % (langchain_mode, have_langchain), flush=True)
|
485 |
|
486 |
if is_public:
|
487 |
allow_upload_to_user_data = False
|
|
|
592 |
|
593 |
if offload_folder:
|
594 |
makedirs(offload_folder)
|
|
|
|
|
595 |
|
596 |
placeholder_instruction, placeholder_input, \
|
597 |
stream_output, show_examples, \
|
|
|
617 |
verbose,
|
618 |
)
|
619 |
|
620 |
+
git_hash = get_githash() if is_public or os.getenv('GET_GITHASH') else "GET_GITHASH"
|
621 |
locals_dict = locals()
|
622 |
locals_print = '\n'.join(['%s: %s' % (k, v) for k, v in locals_dict.items()])
|
623 |
if verbose:
|
|
|
631 |
get_some_dbs_from_hf()
|
632 |
dbs = {}
|
633 |
for langchain_mode1 in visible_langchain_modes:
|
634 |
+
if langchain_mode1 in ['MyData']: # FIXME: Remove other custom temp dbs
|
635 |
# don't use what is on disk, remove it instead
|
636 |
for gpath1 in glob.glob(os.path.join(scratch_base_dir, 'db_dir_%s*' % langchain_mode1)):
|
637 |
if os.path.isdir(gpath1):
|
|
|
646 |
db = prep_langchain(persist_directory1,
|
647 |
load_db_if_exists,
|
648 |
db_type, use_openai_embedding,
|
649 |
+
langchain_mode1, langchain_mode_paths,
|
650 |
hf_embedding_model,
|
651 |
kwargs_make_db=locals())
|
652 |
finally:
|
|
|
665 |
model_state_none = dict(model=None, tokenizer=None, device=None,
|
666 |
base_model=None, tokenizer_base_model=None, lora_weights=None,
|
667 |
inference_server=None, prompt_type=None, prompt_dict=None)
|
668 |
+
my_db_state0 = {LangChainMode.MY_DATA.value: [None, None]}
|
669 |
+
selection_docs_state0 = dict(visible_langchain_modes=visible_langchain_modes,
|
670 |
+
langchain_mode_paths=langchain_mode_paths,
|
671 |
+
langchain_modes=langchain_modes)
|
672 |
+
selection_docs_state = selection_docs_state0
|
673 |
+
langchain_modes0 = langchain_modes
|
674 |
+
langchain_mode_paths0 = langchain_mode_paths
|
675 |
+
visible_langchain_modes0 = visible_langchain_modes
|
676 |
|
677 |
if cli:
|
678 |
from cli import run_cli
|
|
|
1331 |
def evaluate(
|
1332 |
model_state,
|
1333 |
my_db_state,
|
1334 |
+
selection_docs_state,
|
1335 |
# START NOTE: Examples must have same order of parameters
|
1336 |
instruction,
|
1337 |
iinput,
|
|
|
1354 |
instruction_nochat,
|
1355 |
iinput_nochat,
|
1356 |
langchain_mode,
|
1357 |
+
add_chat_history_to_context,
|
1358 |
langchain_action,
|
1359 |
langchain_agents,
|
1360 |
top_k_docs,
|
|
|
1370 |
save_dir=None,
|
1371 |
sanitize_bot_response=False,
|
1372 |
model_state0=None,
|
1373 |
+
langchain_modes0=None,
|
1374 |
+
langchain_mode_paths0=None,
|
1375 |
+
visible_langchain_modes0=None,
|
1376 |
memory_restriction_level=None,
|
1377 |
max_max_new_tokens=None,
|
1378 |
is_public=None,
|
|
|
1383 |
use_llm_if_no_docs=False,
|
1384 |
load_db_if_exists=True,
|
1385 |
dbs=None,
|
|
|
1386 |
detect_user_path_changes_every_query=None,
|
1387 |
use_openai_embedding=None,
|
1388 |
use_openai_model=None,
|
1389 |
hf_embedding_model=None,
|
1390 |
+
cut_distance=None,
|
1391 |
db_type=None,
|
1392 |
n_jobs=None,
|
1393 |
first_para=None,
|
|
|
1416 |
assert chunk_size is not None and isinstance(chunk_size, int)
|
1417 |
assert n_jobs is not None
|
1418 |
assert first_para is not None
|
1419 |
+
assert isinstance(add_chat_history_to_context, bool)
|
1420 |
+
|
1421 |
+
if selection_docs_state is not None:
|
1422 |
+
langchain_modes = selection_docs_state.get('langchain_modes', langchain_modes0)
|
1423 |
+
langchain_mode_paths = selection_docs_state.get('langchain_mode_paths', langchain_mode_paths0)
|
1424 |
+
visible_langchain_modes = selection_docs_state.get('visible_langchain_modes', visible_langchain_modes0)
|
1425 |
+
else:
|
1426 |
+
langchain_modes = langchain_modes0
|
1427 |
+
langchain_mode_paths = langchain_mode_paths0
|
1428 |
+
visible_langchain_modes = visible_langchain_modes0
|
1429 |
|
1430 |
if debug:
|
1431 |
locals_dict = locals().copy()
|
|
|
1547 |
assert langchain_action in langchain_actions, "Invalid langchain_action %s" % langchain_action
|
1548 |
assert len(
|
1549 |
set(langchain_agents).difference(langchain_agents_list)) == 0, "Invalid langchain_agents %s" % langchain_agents
|
1550 |
+
if dbs is not None and langchain_mode in dbs:
|
1551 |
+
db = dbs[langchain_mode]
|
1552 |
+
elif my_db_state is not None and langchain_mode in my_db_state:
|
1553 |
+
db1 = my_db_state[langchain_mode]
|
1554 |
+
if db1 is not None and len(db1) == 2:
|
1555 |
+
db = db1[0]
|
1556 |
+
else:
|
1557 |
+
db = None
|
1558 |
else:
|
1559 |
+
db = None
|
1560 |
+
do_langchain_path = langchain_mode not in [False, 'Disabled', 'LLM'] or \
|
1561 |
base_model in non_hf_types or \
|
1562 |
force_langchain_evaluate
|
1563 |
if do_langchain_path:
|
1564 |
outr = ""
|
1565 |
+
# use smaller cut_distance for wiki_full since so many matches could be obtained, and often irrelevant unless close
|
1566 |
from gpt_langchain import run_qa_db
|
1567 |
gen_hyper_langchain = dict(do_sample=do_sample,
|
1568 |
temperature=temperature,
|
|
|
1585 |
prompter=prompter,
|
1586 |
use_llm_if_no_docs=use_llm_if_no_docs,
|
1587 |
load_db_if_exists=load_db_if_exists,
|
1588 |
+
db=db,
|
1589 |
+
langchain_mode_paths=langchain_mode_paths,
|
1590 |
detect_user_path_changes_every_query=detect_user_path_changes_every_query,
|
1591 |
+
cut_distance=1.1 if langchain_mode in ['wiki_full'] else cut_distance,
|
1592 |
+
add_chat_history_to_context=add_chat_history_to_context,
|
1593 |
use_openai_embedding=use_openai_embedding,
|
1594 |
use_openai_model=use_openai_model,
|
1595 |
hf_embedding_model=hf_embedding_model,
|
|
|
1747 |
chat_client = False
|
1748 |
where_from = "gr_client"
|
1749 |
client_langchain_mode = 'Disabled'
|
1750 |
+
client_add_chat_history_to_context = True
|
1751 |
client_langchain_action = LangChainAction.QUERY.value
|
1752 |
client_langchain_agents = []
|
1753 |
gen_server_kwargs = dict(temperature=temperature,
|
|
|
1801 |
instruction_nochat=gr_prompt if not chat_client else '',
|
1802 |
iinput_nochat=gr_iinput, # only for chat=False
|
1803 |
langchain_mode=client_langchain_mode,
|
1804 |
+
add_chat_history_to_context=client_add_chat_history_to_context,
|
1805 |
langchain_action=client_langchain_action,
|
1806 |
langchain_agents=client_langchain_agents,
|
1807 |
top_k_docs=top_k_docs,
|
1808 |
chunk=chunk,
|
1809 |
chunk_size=chunk_size,
|
1810 |
+
document_subset=DocumentSubset.Relevant.name,
|
1811 |
+
document_choice=[DocumentChoice.ALL.value],
|
1812 |
)
|
1813 |
api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
|
1814 |
if not stream_output:
|
|
|
2102 |
|
2103 |
|
2104 |
inputs_list_names = list(inspect.signature(evaluate).parameters)
|
2105 |
+
state_names = ['model_state', 'my_db_state', 'selection_docs_state']
|
2106 |
inputs_kwargs_list = [x for x in inputs_list_names if x not in eval_func_param_names + state_names]
|
2107 |
|
2108 |
|
|
|
2385 |
|
2386 |
# move to correct position
|
2387 |
for example in examples:
|
2388 |
+
example += [chat, '', '', LangChainMode.DISABLED.value, True, LangChainAction.QUERY.value, [],
|
2389 |
+
top_k_docs, chunk, chunk_size, DocumentSubset.Relevant.name, []
|
2390 |
]
|
2391 |
# adjust examples if non-chat mode
|
2392 |
if not chat:
|
|
|
2446 |
truncation=True,
|
2447 |
max_length=max_length_tokenize).to(smodel.device)
|
2448 |
try:
|
2449 |
+
score = torch.sigmoid(smodel(**inputs.to(smodel.device)).logits[0].float()).cpu().detach().numpy()[0]
|
2450 |
except torch.cuda.OutOfMemoryError as e:
|
2451 |
print("GPU OOM 3: question: %s answer: %s exception: %s" % (question, answer, str(e)), flush=True)
|
2452 |
del inputs
|
|
|
2531 |
return min_top_k_docs, max_top_k_docs, label_top_k_docs
|
2532 |
|
2533 |
|
2534 |
+
def history_to_context(history, langchain_mode1,
|
2535 |
+
add_chat_history_to_context,
|
2536 |
+
prompt_type1, prompt_dict1, chat1, model_max_length1,
|
2537 |
memory_restriction_level1, keep_sources_in_context1):
|
2538 |
"""
|
2539 |
consumes all history up to (but not including) latest history item that is presumed to be an [instruction, None] pair
|
2540 |
:param history:
|
2541 |
:param langchain_mode1:
|
2542 |
+
:param add_chat_history_to_context:
|
2543 |
:param prompt_type1:
|
2544 |
:param prompt_dict1:
|
2545 |
:param chat1:
|
|
|
2552 |
_, _, _, max_prompt_length = get_cutoffs(memory_restriction_level1,
|
2553 |
for_context=True, model_max_length=model_max_length1)
|
2554 |
context1 = ''
|
2555 |
+
if max_prompt_length is not None and add_chat_history_to_context:
|
2556 |
context1 = ''
|
2557 |
# - 1 below because current instruction already in history from user()
|
2558 |
for histi in range(0, len(history) - 1):
|
|
|
2588 |
return context1
|
2589 |
|
2590 |
|
2591 |
+
def update_langchain(langchain_modes, visible_langchain_modes, langchain_mode_paths, extra):
|
2592 |
+
# update from saved state on disk
|
2593 |
+
langchain_modes_from_file, visible_langchain_modes_from_file, langchain_mode_paths_from_file = \
|
2594 |
+
load_collection_enum(extra)
|
2595 |
+
|
2596 |
+
visible_langchain_modes_temp = visible_langchain_modes.copy() + visible_langchain_modes_from_file
|
2597 |
+
visible_langchain_modes.clear() # don't lose original reference
|
2598 |
+
[visible_langchain_modes.append(x) for x in visible_langchain_modes_temp if x not in visible_langchain_modes]
|
2599 |
+
|
2600 |
+
langchain_mode_paths.update(langchain_mode_paths_from_file)
|
2601 |
+
|
2602 |
+
langchain_modes_temp = langchain_modes.copy() + langchain_modes_from_file
|
2603 |
+
langchain_modes.clear() # don't lose original reference
|
2604 |
+
[langchain_modes.append(x) for x in langchain_modes_temp if x not in langchain_modes]
|
2605 |
+
|
2606 |
+
|
2607 |
def entrypoint_main():
|
2608 |
"""
|
2609 |
Examples:
|
gpt4all_llm.py
CHANGED
@@ -95,15 +95,17 @@ def get_llm_gpt4all(model_name,
|
|
95 |
streaming=False,
|
96 |
callbacks=None,
|
97 |
prompter=None,
|
|
|
|
|
98 |
verbose=False,
|
99 |
):
|
100 |
assert prompter is not None
|
101 |
env_gpt4all_file = ".env_gpt4all"
|
102 |
env_kwargs = dotenv_values(env_gpt4all_file)
|
103 |
-
|
104 |
default_kwargs = dict(context_erase=0.5,
|
105 |
n_batch=1,
|
106 |
-
|
107 |
n_predict=max_new_tokens,
|
108 |
repeat_last_n=64 if repetition_penalty != 1.0 else 0,
|
109 |
repeat_penalty=repetition_penalty,
|
@@ -117,7 +119,8 @@ def get_llm_gpt4all(model_name,
|
|
117 |
cls = H2OLlamaCpp
|
118 |
model_path = env_kwargs.pop('model_path_llama') if model is None else model
|
119 |
model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs'])
|
120 |
-
model_kwargs.update(dict(model_path=model_path, callbacks=callbacks, streaming=streaming,
|
|
|
121 |
llm = cls(**model_kwargs)
|
122 |
llm.client.verbose = verbose
|
123 |
elif model_name == 'gpt4all_llama':
|
@@ -125,14 +128,16 @@ def get_llm_gpt4all(model_name,
|
|
125 |
model_path = env_kwargs.pop('model_path_gpt4all_llama') if model is None else model
|
126 |
model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs'])
|
127 |
model_kwargs.update(
|
128 |
-
dict(model=model_path, backend='llama', callbacks=callbacks, streaming=streaming,
|
|
|
129 |
llm = cls(**model_kwargs)
|
130 |
elif model_name == 'gptj':
|
131 |
cls = H2OGPT4All
|
132 |
model_path = env_kwargs.pop('model_path_gptj') if model is None else model
|
133 |
model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs'])
|
134 |
model_kwargs.update(
|
135 |
-
dict(model=model_path, backend='gptj', callbacks=callbacks, streaming=streaming,
|
|
|
136 |
llm = cls(**model_kwargs)
|
137 |
else:
|
138 |
raise RuntimeError("No such model_name %s" % model_name)
|
@@ -142,6 +147,8 @@ def get_llm_gpt4all(model_name,
|
|
142 |
class H2OGPT4All(gpt4all.GPT4All):
|
143 |
model: Any
|
144 |
prompter: Any
|
|
|
|
|
145 |
"""Path to the pre-trained GPT4All model file."""
|
146 |
|
147 |
@root_validator()
|
@@ -187,10 +194,11 @@ class H2OGPT4All(gpt4all.GPT4All):
|
|
187 |
**kwargs,
|
188 |
) -> str:
|
189 |
# Roughly 4 chars per token if natural language
|
190 |
-
|
|
|
191 |
|
192 |
# use instruct prompting
|
193 |
-
data_point = dict(context=
|
194 |
prompt = self.prompter.generate_prompt(data_point)
|
195 |
|
196 |
verbose = False
|
@@ -206,6 +214,8 @@ from langchain.llms import LlamaCpp
|
|
206 |
class H2OLlamaCpp(LlamaCpp):
|
207 |
model_path: Any
|
208 |
prompter: Any
|
|
|
|
|
209 |
"""Path to the pre-trained GPT4All model file."""
|
210 |
|
211 |
@root_validator()
|
@@ -276,7 +286,7 @@ class H2OLlamaCpp(LlamaCpp):
|
|
276 |
print("reduced tokens from %d -> %d" % (num_prompt_tokens, num_prompt_tokens2), flush=True)
|
277 |
|
278 |
# use instruct prompting
|
279 |
-
data_point = dict(context=
|
280 |
prompt = self.prompter.generate_prompt(data_point)
|
281 |
|
282 |
if verbose:
|
|
|
95 |
streaming=False,
|
96 |
callbacks=None,
|
97 |
prompter=None,
|
98 |
+
context='',
|
99 |
+
iinput='',
|
100 |
verbose=False,
|
101 |
):
|
102 |
assert prompter is not None
|
103 |
env_gpt4all_file = ".env_gpt4all"
|
104 |
env_kwargs = dotenv_values(env_gpt4all_file)
|
105 |
+
max_tokens = env_kwargs.pop('max_tokens', 2048 - max_new_tokens)
|
106 |
default_kwargs = dict(context_erase=0.5,
|
107 |
n_batch=1,
|
108 |
+
max_tokens=max_tokens,
|
109 |
n_predict=max_new_tokens,
|
110 |
repeat_last_n=64 if repetition_penalty != 1.0 else 0,
|
111 |
repeat_penalty=repetition_penalty,
|
|
|
119 |
cls = H2OLlamaCpp
|
120 |
model_path = env_kwargs.pop('model_path_llama') if model is None else model
|
121 |
model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs'])
|
122 |
+
model_kwargs.update(dict(model_path=model_path, callbacks=callbacks, streaming=streaming,
|
123 |
+
prompter=prompter, context=context, iinput=iinput))
|
124 |
llm = cls(**model_kwargs)
|
125 |
llm.client.verbose = verbose
|
126 |
elif model_name == 'gpt4all_llama':
|
|
|
128 |
model_path = env_kwargs.pop('model_path_gpt4all_llama') if model is None else model
|
129 |
model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs'])
|
130 |
model_kwargs.update(
|
131 |
+
dict(model=model_path, backend='llama', callbacks=callbacks, streaming=streaming,
|
132 |
+
prompter=prompter, context=context, iinput=iinput))
|
133 |
llm = cls(**model_kwargs)
|
134 |
elif model_name == 'gptj':
|
135 |
cls = H2OGPT4All
|
136 |
model_path = env_kwargs.pop('model_path_gptj') if model is None else model
|
137 |
model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs'])
|
138 |
model_kwargs.update(
|
139 |
+
dict(model=model_path, backend='gptj', callbacks=callbacks, streaming=streaming,
|
140 |
+
prompter=prompter, context=context, iinput=iinput))
|
141 |
llm = cls(**model_kwargs)
|
142 |
else:
|
143 |
raise RuntimeError("No such model_name %s" % model_name)
|
|
|
147 |
class H2OGPT4All(gpt4all.GPT4All):
|
148 |
model: Any
|
149 |
prompter: Any
|
150 |
+
context: Any = ''
|
151 |
+
iinput: Any = ''
|
152 |
"""Path to the pre-trained GPT4All model file."""
|
153 |
|
154 |
@root_validator()
|
|
|
194 |
**kwargs,
|
195 |
) -> str:
|
196 |
# Roughly 4 chars per token if natural language
|
197 |
+
n_ctx = 2048
|
198 |
+
prompt = prompt[-self.max_tokens * 4:]
|
199 |
|
200 |
# use instruct prompting
|
201 |
+
data_point = dict(context=self.context, instruction=prompt, input=self.iinput)
|
202 |
prompt = self.prompter.generate_prompt(data_point)
|
203 |
|
204 |
verbose = False
|
|
|
214 |
class H2OLlamaCpp(LlamaCpp):
|
215 |
model_path: Any
|
216 |
prompter: Any
|
217 |
+
context: Any
|
218 |
+
iinput: Any
|
219 |
"""Path to the pre-trained GPT4All model file."""
|
220 |
|
221 |
@root_validator()
|
|
|
286 |
print("reduced tokens from %d -> %d" % (num_prompt_tokens, num_prompt_tokens2), flush=True)
|
287 |
|
288 |
# use instruct prompting
|
289 |
+
data_point = dict(context=self.context, instruction=prompt, input=self.iinput)
|
290 |
prompt = self.prompter.generate_prompt(data_point)
|
291 |
|
292 |
if verbose:
|
gpt_langchain.py
CHANGED
@@ -24,8 +24,8 @@ from langchain.embeddings import HuggingFaceInstructEmbeddings
|
|
24 |
from langchain.schema import LLMResult
|
25 |
from tqdm import tqdm
|
26 |
|
27 |
-
from enums import
|
28 |
-
LangChainAction, LangChainMode
|
29 |
from evaluate_params import gen_hyper
|
30 |
from gen import get_model, SEED
|
31 |
from prompter import non_hf_types, PromptType, Prompter
|
@@ -96,11 +96,15 @@ def get_db(sources, use_openai_embedding=False, db_type='faiss',
|
|
96 |
db = get_existing_db(None, persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
|
97 |
hf_embedding_model, verbose=False)
|
98 |
if db is None:
|
|
|
|
|
|
|
|
|
99 |
db = Chroma.from_documents(documents=sources,
|
100 |
embedding=embedding,
|
101 |
persist_directory=persist_directory,
|
102 |
collection_name=collection_name,
|
103 |
-
|
104 |
db.persist()
|
105 |
clear_embedding(db)
|
106 |
save_embed(db, use_openai_embedding, hf_embedding_model)
|
@@ -305,6 +309,8 @@ class GradioInference(LLM):
|
|
305 |
sanitize_bot_response: bool = False
|
306 |
|
307 |
prompter: Any = None
|
|
|
|
|
308 |
client: Any = None
|
309 |
|
310 |
class Config:
|
@@ -348,14 +354,15 @@ class GradioInference(LLM):
|
|
348 |
stream_output = self.stream
|
349 |
gr_client = self.client
|
350 |
client_langchain_mode = 'Disabled'
|
|
|
351 |
client_langchain_action = LangChainAction.QUERY.value
|
352 |
client_langchain_agents = []
|
353 |
top_k_docs = 1
|
354 |
chunk = True
|
355 |
chunk_size = 512
|
356 |
client_kwargs = dict(instruction=prompt if self.chat_client else '', # only for chat=True
|
357 |
-
iinput='', # only for chat=True
|
358 |
-
context=
|
359 |
# streaming output is supported, loops over and outputs each generation in streaming mode
|
360 |
# but leave stream_output=False for simple input/output mode
|
361 |
stream_output=stream_output,
|
@@ -376,15 +383,16 @@ class GradioInference(LLM):
|
|
376 |
chat=self.chat_client,
|
377 |
|
378 |
instruction_nochat=prompt if not self.chat_client else '',
|
379 |
-
iinput_nochat='',
|
380 |
langchain_mode=client_langchain_mode,
|
|
|
381 |
langchain_action=client_langchain_action,
|
382 |
langchain_agents=client_langchain_agents,
|
383 |
top_k_docs=top_k_docs,
|
384 |
chunk=chunk,
|
385 |
chunk_size=chunk_size,
|
386 |
-
document_subset=
|
387 |
-
document_choice=[],
|
388 |
)
|
389 |
api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
|
390 |
if not stream_output:
|
@@ -454,6 +462,8 @@ class H2OHuggingFaceTextGenInference(HuggingFaceTextGenInference):
|
|
454 |
stream: bool = False
|
455 |
sanitize_bot_response: bool = False
|
456 |
prompter: Any = None
|
|
|
|
|
457 |
tokenizer: Any = None
|
458 |
client: Any = None
|
459 |
|
@@ -495,7 +505,7 @@ class H2OHuggingFaceTextGenInference(HuggingFaceTextGenInference):
|
|
495 |
prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer)
|
496 |
|
497 |
# NOTE: TGI server does not add prompting, so must do here
|
498 |
-
data_point = dict(context=
|
499 |
prompt = self.prompter.generate_prompt(data_point)
|
500 |
|
501 |
gen_server_kwargs = dict(do_sample=self.do_sample,
|
@@ -574,6 +584,8 @@ class H2OOpenAI(OpenAI):
|
|
574 |
stop_sequences: Any = None
|
575 |
sanitize_bot_response: bool = False
|
576 |
prompter: Any = None
|
|
|
|
|
577 |
tokenizer: Any = None
|
578 |
|
579 |
@classmethod
|
@@ -599,7 +611,7 @@ class H2OOpenAI(OpenAI):
|
|
599 |
for prompti, prompt in enumerate(prompts):
|
600 |
prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer)
|
601 |
# NOTE: OpenAI/vLLM server does not add prompting, so must do here
|
602 |
-
data_point = dict(context=
|
603 |
prompt = self.prompter.generate_prompt(data_point)
|
604 |
prompts[prompti] = prompt
|
605 |
|
@@ -677,17 +689,22 @@ def get_llm(use_openai_model=False,
|
|
677 |
prompt_type=None,
|
678 |
prompt_dict=None,
|
679 |
prompter=None,
|
|
|
|
|
680 |
sanitize_bot_response=False,
|
681 |
verbose=False,
|
682 |
):
|
|
|
|
|
683 |
if use_openai_model or inference_server.startswith('openai') or inference_server.startswith('vllm'):
|
684 |
if use_openai_model and model_name is None:
|
685 |
model_name = "gpt-3.5-turbo"
|
686 |
-
|
687 |
-
|
688 |
kwargs_extra = {}
|
689 |
if inference_server == 'openai_chat' or inf_type == 'vllm_chat':
|
690 |
cls = H2OChatOpenAI
|
|
|
691 |
else:
|
692 |
cls = H2OOpenAI
|
693 |
if inf_type == 'vllm':
|
@@ -697,6 +714,8 @@ def get_llm(use_openai_model=False,
|
|
697 |
kwargs_extra = dict(stop_sequences=stop_sequences,
|
698 |
sanitize_bot_response=sanitize_bot_response,
|
699 |
prompter=prompter,
|
|
|
|
|
700 |
tokenizer=tokenizer,
|
701 |
client=None)
|
702 |
|
@@ -711,7 +730,7 @@ def get_llm(use_openai_model=False,
|
|
711 |
callbacks=callbacks if stream_output else None,
|
712 |
openai_api_key=openai.api_key,
|
713 |
openai_api_base=openai.api_base,
|
714 |
-
logit_bias=None if inf_type =='vllm' else {},
|
715 |
max_retries=2,
|
716 |
streaming=stream_output,
|
717 |
**kwargs_extra
|
@@ -769,6 +788,8 @@ def get_llm(use_openai_model=False,
|
|
769 |
callbacks=callbacks if stream_output else None,
|
770 |
stream=stream_output,
|
771 |
prompter=prompter,
|
|
|
|
|
772 |
client=gr_client,
|
773 |
sanitize_bot_response=sanitize_bot_response,
|
774 |
)
|
@@ -789,6 +810,8 @@ def get_llm(use_openai_model=False,
|
|
789 |
callbacks=callbacks if stream_output else None,
|
790 |
stream=stream_output,
|
791 |
prompter=prompter,
|
|
|
|
|
792 |
tokenizer=tokenizer,
|
793 |
client=hf_client,
|
794 |
timeout=max_time,
|
@@ -821,6 +844,8 @@ def get_llm(use_openai_model=False,
|
|
821 |
verbose=verbose,
|
822 |
streaming=stream_output,
|
823 |
prompter=prompter,
|
|
|
|
|
824 |
)
|
825 |
else:
|
826 |
if model is None:
|
@@ -863,6 +888,8 @@ def get_llm(use_openai_model=False,
|
|
863 |
from h2oai_pipeline import H2OTextGenerationPipeline
|
864 |
pipe = H2OTextGenerationPipeline(model=model, use_prompter=True,
|
865 |
prompter=prompter,
|
|
|
|
|
866 |
prompt_type=prompt_type,
|
867 |
prompt_dict=prompt_dict,
|
868 |
sanitize_bot_response=sanitize_bot_response,
|
@@ -1048,7 +1075,7 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
|
|
1048 |
is_url=False, is_txt=False,
|
1049 |
enable_captions=True,
|
1050 |
captions_model=None,
|
1051 |
-
enable_ocr=False, caption_loader=None,
|
1052 |
headsize=50):
|
1053 |
if file is None:
|
1054 |
if fail_any_exception:
|
@@ -1065,6 +1092,7 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
|
|
1065 |
base_name = sanitize_filename(base_name) + "_" + str(uuid.uuid4())[:10]
|
1066 |
base_path = os.path.join(dir_name, base_name)
|
1067 |
if is_url:
|
|
|
1068 |
if file.lower().startswith('arxiv:'):
|
1069 |
query = file.lower().split('arxiv:')
|
1070 |
if len(query) == 2 and have_arxiv:
|
@@ -1216,21 +1244,54 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
|
|
1216 |
from dotenv import dotenv_values
|
1217 |
env_kwargs = dotenv_values(env_gpt4all_file)
|
1218 |
pdf_class_name = env_kwargs.get('PDF_CLASS_NAME', 'PyMuPDFParser')
|
|
|
|
|
1219 |
if have_pymupdf and pdf_class_name == 'PyMuPDFParser':
|
1220 |
# GPL, only use if installed
|
1221 |
from langchain.document_loaders import PyMuPDFLoader
|
1222 |
# load() still chunks by pages, but every page has title at start to help
|
1223 |
doc1 = PyMuPDFLoader(file).load()
|
|
|
|
|
|
|
1224 |
doc1 = clean_doc(doc1)
|
1225 |
-
|
1226 |
doc1 = UnstructuredPDFLoader(file).load()
|
|
|
|
|
|
|
1227 |
# seems to not need cleaning in most cases
|
1228 |
-
|
1229 |
# open-source fallback
|
1230 |
# load() still chunks by pages, but every page has title at start to help
|
1231 |
doc1 = PyPDFLoader(file).load()
|
|
|
|
|
|
|
1232 |
doc1 = clean_doc(doc1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1233 |
# Some PDFs return nothing or junk from PDFMinerLoader
|
|
|
|
|
|
|
|
|
|
|
|
|
1234 |
doc1 = chunk_sources(doc1, chunk=chunk, chunk_size=chunk_size)
|
1235 |
add_meta(doc1, file)
|
1236 |
elif file.lower().endswith('.csv'):
|
@@ -1283,7 +1344,7 @@ def path_to_doc1(file, verbose=False, fail_any_exception=False, return_file=True
|
|
1283 |
is_url=False, is_txt=False,
|
1284 |
enable_captions=True,
|
1285 |
captions_model=None,
|
1286 |
-
enable_ocr=False, caption_loader=None):
|
1287 |
if verbose:
|
1288 |
if is_url:
|
1289 |
print("Ingesting URL: %s" % file, flush=True)
|
@@ -1301,6 +1362,7 @@ def path_to_doc1(file, verbose=False, fail_any_exception=False, return_file=True
|
|
1301 |
enable_captions=enable_captions,
|
1302 |
captions_model=captions_model,
|
1303 |
enable_ocr=enable_ocr,
|
|
|
1304 |
caption_loader=caption_loader)
|
1305 |
except BaseException as e:
|
1306 |
print("Failed to ingest %s due to %s" % (file, traceback.format_exc()))
|
@@ -1309,7 +1371,7 @@ def path_to_doc1(file, verbose=False, fail_any_exception=False, return_file=True
|
|
1309 |
else:
|
1310 |
exception_doc = Document(
|
1311 |
page_content='',
|
1312 |
-
metadata={"source": file, "exception": '%s
|
1313 |
"traceback": traceback.format_exc()})
|
1314 |
res = [exception_doc]
|
1315 |
if return_file:
|
@@ -1330,6 +1392,7 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
|
|
1330 |
captions_model=None,
|
1331 |
caption_loader=None,
|
1332 |
enable_ocr=False,
|
|
|
1333 |
existing_files=[],
|
1334 |
existing_hash_ids={},
|
1335 |
):
|
@@ -1351,11 +1414,15 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
|
|
1351 |
[globs_non_image_types.extend(glob.glob(os.path.join(path, "./**/*.%s" % ftype), recursive=True))
|
1352 |
for ftype in non_image_types]
|
1353 |
else:
|
1354 |
-
if isinstance(path_or_paths, str)
|
1355 |
-
path_or_paths
|
|
|
|
|
|
|
|
|
1356 |
# list/tuple of files (consume what can, and exception those that selected but cannot consume so user knows)
|
1357 |
-
assert isinstance(path_or_paths, (list, tuple, types.GeneratorType)),
|
1358 |
-
path_or_paths)
|
1359 |
# reform out of allowed types
|
1360 |
globs_image_types.extend(flatten_list([[x for x in path_or_paths if x.endswith(y)] for y in image_types]))
|
1361 |
# could do below:
|
@@ -1407,6 +1474,7 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
|
|
1407 |
captions_model=captions_model,
|
1408 |
caption_loader=caption_loader,
|
1409 |
enable_ocr=enable_ocr,
|
|
|
1410 |
)
|
1411 |
|
1412 |
if n_jobs != 1 and len(globs_non_image_types) > 1:
|
@@ -1439,7 +1507,7 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
|
|
1439 |
with open(fil, 'rb') as f:
|
1440 |
documents.extend(pickle.load(f))
|
1441 |
# remove temp pickle
|
1442 |
-
|
1443 |
else:
|
1444 |
documents = reduce(concat, documents)
|
1445 |
return documents
|
@@ -1447,7 +1515,7 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
|
|
1447 |
|
1448 |
def prep_langchain(persist_directory,
|
1449 |
load_db_if_exists,
|
1450 |
-
db_type, use_openai_embedding, langchain_mode,
|
1451 |
hf_embedding_model, n_jobs=-1, kwargs_make_db={}):
|
1452 |
"""
|
1453 |
do prep first time, involving downloads
|
@@ -1457,6 +1525,7 @@ def prep_langchain(persist_directory,
|
|
1457 |
assert langchain_mode not in ['MyData'], "Should not prep scratch data"
|
1458 |
|
1459 |
db_dir_exists = os.path.isdir(persist_directory)
|
|
|
1460 |
|
1461 |
if db_dir_exists and user_path is None:
|
1462 |
print("Prep: persist_directory=%s exists, using" % persist_directory, flush=True)
|
@@ -1592,7 +1661,7 @@ def make_db(**langchain_kwargs):
|
|
1592 |
langchain_kwargs[k] = defaults_db[k]
|
1593 |
# final check for missing
|
1594 |
missing_kwargs = [x for x in func_names if x not in langchain_kwargs]
|
1595 |
-
assert not missing_kwargs, "Missing kwargs: %s" % missing_kwargs
|
1596 |
# only keep actual used
|
1597 |
langchain_kwargs = {k: v for k, v in langchain_kwargs.items() if k in func_names}
|
1598 |
return _make_db(**langchain_kwargs)
|
@@ -1626,13 +1695,14 @@ def _make_db(use_openai_embedding=False,
|
|
1626 |
first_para=False, text_limit=None,
|
1627 |
chunk=True, chunk_size=512,
|
1628 |
langchain_mode=None,
|
1629 |
-
|
1630 |
db_type='faiss',
|
1631 |
load_db_if_exists=True,
|
1632 |
db=None,
|
1633 |
n_jobs=-1,
|
1634 |
verbose=False):
|
1635 |
persist_directory = get_persist_directory(langchain_mode)
|
|
|
1636 |
# see if can get persistent chroma db
|
1637 |
db_trial = get_existing_db(db, persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
|
1638 |
hf_embedding_model, verbose=verbose)
|
@@ -1640,23 +1710,8 @@ def _make_db(use_openai_embedding=False,
|
|
1640 |
db = db_trial
|
1641 |
|
1642 |
sources = []
|
1643 |
-
if not db
|
1644 |
-
|
1645 |
-
langchain_mode in ['UserData']:
|
1646 |
-
# Should not make MyData db this way, why avoided, only upload from UI
|
1647 |
-
assert langchain_mode not in ['MyData'], "Should not make MyData db this way"
|
1648 |
-
if verbose:
|
1649 |
-
if langchain_mode in ['UserData']:
|
1650 |
-
if user_path is not None:
|
1651 |
-
print("Checking if changed or new sources in %s, and generating sources them" % user_path,
|
1652 |
-
flush=True)
|
1653 |
-
elif db is None:
|
1654 |
-
print("user_path not passed and no db, no sources", flush=True)
|
1655 |
-
else:
|
1656 |
-
print("user_path not passed, using only existing db, no new sources", flush=True)
|
1657 |
-
else:
|
1658 |
-
print("Generating %s sources" % langchain_mode, flush=True)
|
1659 |
-
if langchain_mode in ['wiki_full', 'All', "'All'"]:
|
1660 |
from read_wiki_full import get_all_documents
|
1661 |
small_test = None
|
1662 |
print("Generating new wiki", flush=True)
|
@@ -1666,55 +1721,48 @@ def _make_db(use_openai_embedding=False,
|
|
1666 |
sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
|
1667 |
print("Chunked new wiki", flush=True)
|
1668 |
sources.extend(sources1)
|
1669 |
-
|
1670 |
sources1 = get_wiki_sources(first_para=first_para, text_limit=text_limit)
|
1671 |
if chunk:
|
1672 |
sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
|
1673 |
sources.extend(sources1)
|
1674 |
-
|
1675 |
# sources = get_github_docs("dagster-io", "dagster")
|
1676 |
sources1 = get_github_docs("h2oai", "h2ogpt")
|
1677 |
# FIXME: always chunk for now
|
1678 |
sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
|
1679 |
sources.extend(sources1)
|
1680 |
-
|
1681 |
sources1 = get_dai_docs(from_hf=True)
|
1682 |
if chunk and False: # FIXME: DAI docs are already chunked well, should only chunk more if over limit
|
1683 |
sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
|
1684 |
sources.extend(sources1)
|
1685 |
-
|
1686 |
-
|
1687 |
-
|
1688 |
-
|
1689 |
-
|
1690 |
-
|
1691 |
-
|
1692 |
-
|
1693 |
-
|
1694 |
-
|
1695 |
-
|
1696 |
-
|
1697 |
-
|
1698 |
-
|
1699 |
-
|
1700 |
-
|
1701 |
-
|
1702 |
-
|
1703 |
-
|
1704 |
-
|
1705 |
-
|
1706 |
-
|
1707 |
-
|
1708 |
-
|
1709 |
-
|
1710 |
-
|
1711 |
-
# from langchain.document_loaders import UnstructuredURLLoader
|
1712 |
-
# loader = UnstructuredURLLoader(urls=urls)
|
1713 |
-
urls = ["https://www.birdsongsf.com/who-we-are/"]
|
1714 |
-
from langchain.document_loaders import PlaywrightURLLoader
|
1715 |
-
loader = PlaywrightURLLoader(urls=urls, remove_selectors=["header", "footer"])
|
1716 |
-
sources1 = loader.load()
|
1717 |
-
sources.extend(sources1)
|
1718 |
if not sources:
|
1719 |
if verbose:
|
1720 |
if db is not None:
|
@@ -1737,7 +1785,7 @@ def _make_db(use_openai_embedding=False,
|
|
1737 |
else:
|
1738 |
print("Did not generate db since no sources", flush=True)
|
1739 |
new_sources_metadata = [x.metadata for x in sources]
|
1740 |
-
elif user_path is not None
|
1741 |
print("Existing db, potentially adding %s sources from user_path=%s" % (len(sources), user_path), flush=True)
|
1742 |
db, num_new_sources, new_sources_metadata = add_to_db(db, sources, db_type=db_type,
|
1743 |
use_openai_embedding=use_openai_embedding,
|
@@ -1835,7 +1883,7 @@ def run_qa_db(**kwargs):
|
|
1835 |
kwargs['answer_with_sources'] = True
|
1836 |
kwargs['show_rank'] = False
|
1837 |
missing_kwargs = [x for x in func_names if x not in kwargs]
|
1838 |
-
assert not missing_kwargs, "Missing kwargs: %s" % missing_kwargs
|
1839 |
# only keep actual used
|
1840 |
kwargs = {k: v for k, v in kwargs.items() if k in func_names}
|
1841 |
try:
|
@@ -1849,7 +1897,7 @@ def _run_qa_db(query=None,
|
|
1849 |
context=None,
|
1850 |
use_openai_model=False, use_openai_embedding=False,
|
1851 |
first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
|
1852 |
-
|
1853 |
detect_user_path_changes_every_query=False,
|
1854 |
db_type='faiss',
|
1855 |
model_name=None, model=None, tokenizer=None, inference_server=None,
|
@@ -1859,7 +1907,8 @@ def _run_qa_db(query=None,
|
|
1859 |
prompt_type=None,
|
1860 |
prompt_dict=None,
|
1861 |
answer_with_sources=True,
|
1862 |
-
|
|
|
1863 |
sanitize_bot_response=False,
|
1864 |
show_rank=False,
|
1865 |
use_llm_if_no_docs=False,
|
@@ -1879,8 +1928,8 @@ def _run_qa_db(query=None,
|
|
1879 |
langchain_mode=None,
|
1880 |
langchain_action=None,
|
1881 |
langchain_agents=None,
|
1882 |
-
document_subset=
|
1883 |
-
document_choice=[],
|
1884 |
n_jobs=-1,
|
1885 |
verbose=False,
|
1886 |
cli=False,
|
@@ -1899,7 +1948,7 @@ def _run_qa_db(query=None,
|
|
1899 |
:param top_k_docs:
|
1900 |
:param chunk:
|
1901 |
:param chunk_size:
|
1902 |
-
:param
|
1903 |
:param db_type: 'faiss' for in-memory db or 'chroma' or 'weaviate' for persistent db
|
1904 |
:param model_name: model name, used to switch behaviors
|
1905 |
:param model: pre-initialized model, else will make new one
|
@@ -1907,6 +1956,7 @@ def _run_qa_db(query=None,
|
|
1907 |
:param answer_with_sources
|
1908 |
:return:
|
1909 |
"""
|
|
|
1910 |
if model is not None:
|
1911 |
assert model_name is not None # require so can make decisions
|
1912 |
assert query is not None
|
@@ -1921,6 +1971,8 @@ def _run_qa_db(query=None,
|
|
1921 |
else:
|
1922 |
prompt_dict = ''
|
1923 |
assert len(set(gen_hyper).difference(inspect.signature(get_llm).parameters)) == 0
|
|
|
|
|
1924 |
llm, model_name, streamer, prompt_type_out = get_llm(use_openai_model=use_openai_model, model_name=model_name,
|
1925 |
model=model,
|
1926 |
tokenizer=tokenizer,
|
@@ -1940,11 +1992,13 @@ def _run_qa_db(query=None,
|
|
1940 |
prompt_type=prompt_type,
|
1941 |
prompt_dict=prompt_dict,
|
1942 |
prompter=prompter,
|
|
|
|
|
1943 |
sanitize_bot_response=sanitize_bot_response,
|
1944 |
verbose=verbose,
|
1945 |
)
|
1946 |
|
1947 |
-
|
1948 |
scores = []
|
1949 |
chain = None
|
1950 |
|
@@ -1956,9 +2010,13 @@ def _run_qa_db(query=None,
|
|
1956 |
sim_kwargs = {k: v for k, v in locals().items() if k in func_names}
|
1957 |
missing_kwargs = [x for x in func_names if x not in sim_kwargs]
|
1958 |
assert not missing_kwargs, "Missing: %s" % missing_kwargs
|
1959 |
-
docs, chain, scores,
|
1960 |
if document_subset in non_query_commands:
|
1961 |
formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
|
|
|
|
|
|
|
|
|
1962 |
yield formatted_doc_chunks, ''
|
1963 |
return
|
1964 |
if not use_llm_if_no_docs:
|
@@ -1970,7 +2028,6 @@ def _run_qa_db(query=None,
|
|
1970 |
yield ret, extra
|
1971 |
return
|
1972 |
if not docs and langchain_mode not in [LangChainMode.DISABLED.value,
|
1973 |
-
LangChainMode.CHAT_LLM.value,
|
1974 |
LangChainMode.LLM.value]:
|
1975 |
ret = 'No relevant documents to query.' if have_any_docs else 'No documents to query.'
|
1976 |
extra = ''
|
@@ -2026,7 +2083,7 @@ def _run_qa_db(query=None,
|
|
2026 |
else:
|
2027 |
answer = chain()
|
2028 |
|
2029 |
-
if not
|
2030 |
ret = answer['output_text']
|
2031 |
extra = ''
|
2032 |
yield ret, extra
|
@@ -2038,9 +2095,10 @@ def _run_qa_db(query=None,
|
|
2038 |
|
2039 |
def get_chain(query=None,
|
2040 |
iinput=None,
|
|
|
2041 |
use_openai_model=False, use_openai_embedding=False,
|
2042 |
first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
|
2043 |
-
|
2044 |
detect_user_path_changes_every_query=False,
|
2045 |
db_type='faiss',
|
2046 |
model_name=None,
|
@@ -2048,14 +2106,15 @@ def get_chain(query=None,
|
|
2048 |
hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
|
2049 |
prompt_type=None,
|
2050 |
prompt_dict=None,
|
2051 |
-
|
|
|
2052 |
load_db_if_exists=False,
|
2053 |
db=None,
|
2054 |
langchain_mode=None,
|
2055 |
langchain_action=None,
|
2056 |
langchain_agents=None,
|
2057 |
-
document_subset=
|
2058 |
-
document_choice=[],
|
2059 |
n_jobs=-1,
|
2060 |
# beyond run_db_query:
|
2061 |
llm=None,
|
@@ -2070,12 +2129,12 @@ def get_chain(query=None,
|
|
2070 |
assert langchain_agents is not None # should be at least []
|
2071 |
# determine whether use of context out of docs is planned
|
2072 |
if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
|
2073 |
-
if langchain_mode in ['Disabled', '
|
2074 |
-
|
2075 |
else:
|
2076 |
-
|
2077 |
else:
|
2078 |
-
|
2079 |
|
2080 |
# https://github.com/hwchase17/langchain/issues/1946
|
2081 |
# FIXME: Seems no way to get size of chroma db to limit top_k_docs to avoid
|
@@ -2092,14 +2151,17 @@ def get_chain(query=None,
|
|
2092 |
# avoid looking at user_path during similarity search db handling,
|
2093 |
# if already have db and not updating from user_path every query
|
2094 |
# but if db is None, no db yet loaded (e.g. from prep), so allow user_path to be whatever it was
|
2095 |
-
|
|
|
|
|
|
|
2096 |
db, num_new_sources, new_sources_metadata = make_db(use_openai_embedding=use_openai_embedding,
|
2097 |
hf_embedding_model=hf_embedding_model,
|
2098 |
first_para=first_para, text_limit=text_limit,
|
2099 |
chunk=chunk,
|
2100 |
chunk_size=chunk_size,
|
2101 |
langchain_mode=langchain_mode,
|
2102 |
-
|
2103 |
db_type=db_type,
|
2104 |
load_db_if_exists=load_db_if_exists,
|
2105 |
db=db,
|
@@ -2119,7 +2181,7 @@ def get_chain(query=None,
|
|
2119 |
else:
|
2120 |
extra = ""
|
2121 |
prefix = ""
|
2122 |
-
if langchain_mode in ['Disabled', '
|
2123 |
template_if_no_docs = template = """%s{context}{question}""" % prefix
|
2124 |
else:
|
2125 |
template = """%s
|
@@ -2160,7 +2222,7 @@ def get_chain(query=None,
|
|
2160 |
else:
|
2161 |
use_template = False
|
2162 |
|
2163 |
-
if db and
|
2164 |
base_path = 'locks'
|
2165 |
makedirs(base_path)
|
2166 |
if hasattr(db, '_persist_directory'):
|
@@ -2174,10 +2236,10 @@ def get_chain(query=None,
|
|
2174 |
filter_kwargs = {}
|
2175 |
else:
|
2176 |
assert document_choice is not None, "Document choice was None"
|
2177 |
-
if len(document_choice) >= 1 and document_choice[0] ==
|
2178 |
filter_kwargs = {}
|
2179 |
elif len(document_choice) >= 2:
|
2180 |
-
if document_choice[0] ==
|
2181 |
# remove 'All'
|
2182 |
document_choice = document_choice[1:]
|
2183 |
or_filter = [{"source": {"$eq": x}} for x in document_choice]
|
@@ -2189,10 +2251,10 @@ def get_chain(query=None,
|
|
2189 |
else:
|
2190 |
# shouldn't reach
|
2191 |
filter_kwargs = {}
|
2192 |
-
if langchain_mode in [LangChainMode.LLM.value
|
2193 |
docs = []
|
2194 |
scores = []
|
2195 |
-
elif document_subset ==
|
2196 |
db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs)
|
2197 |
# similar to langchain's chroma's _results_to_docs_and_scores
|
2198 |
docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0)
|
@@ -2280,8 +2342,8 @@ def get_chain(query=None,
|
|
2280 |
docs_with_score.reverse()
|
2281 |
# cut off so no high distance docs/sources considered
|
2282 |
have_any_docs |= len(docs_with_score) > 0 # before cut
|
2283 |
-
docs = [x[0] for x in docs_with_score if x[1] <
|
2284 |
-
scores = [x[1] for x in docs_with_score if x[1] <
|
2285 |
if len(scores) > 0 and verbose:
|
2286 |
print("Distance: min: %s max: %s mean: %s median: %s" %
|
2287 |
(scores[0], scores[-1], np.mean(scores), np.median(scores)), flush=True)
|
@@ -2289,7 +2351,7 @@ def get_chain(query=None,
|
|
2289 |
docs = []
|
2290 |
scores = []
|
2291 |
|
2292 |
-
if not docs and
|
2293 |
# if HF type and have no docs, can bail out
|
2294 |
return docs, None, [], False, have_any_docs
|
2295 |
|
@@ -2312,7 +2374,7 @@ def get_chain(query=None,
|
|
2312 |
|
2313 |
if len(docs) == 0:
|
2314 |
# avoid context == in prompt then
|
2315 |
-
|
2316 |
template = template_if_no_docs
|
2317 |
|
2318 |
if langchain_action == LangChainAction.QUERY.value:
|
@@ -2328,7 +2390,7 @@ def get_chain(query=None,
|
|
2328 |
else:
|
2329 |
# only if use_openai_model = True, unused normally except in testing
|
2330 |
chain = load_qa_with_sources_chain(llm)
|
2331 |
-
if not
|
2332 |
chain_kwargs = dict(input_documents=[], question=query)
|
2333 |
else:
|
2334 |
chain_kwargs = dict(input_documents=docs, question=query)
|
@@ -2355,7 +2417,7 @@ def get_chain(query=None,
|
|
2355 |
else:
|
2356 |
raise RuntimeError("No such langchain_action=%s" % langchain_action)
|
2357 |
|
2358 |
-
return docs, target, scores,
|
2359 |
|
2360 |
|
2361 |
def get_sources_answer(query, answer, scores, show_rank, answer_with_sources, verbose=False):
|
|
|
24 |
from langchain.schema import LLMResult
|
25 |
from tqdm import tqdm
|
26 |
|
27 |
+
from enums import DocumentSubset, no_lora_str, model_token_mapping, source_prefix, source_postfix, non_query_commands, \
|
28 |
+
LangChainAction, LangChainMode, DocumentChoice
|
29 |
from evaluate_params import gen_hyper
|
30 |
from gen import get_model, SEED
|
31 |
from prompter import non_hf_types, PromptType, Prompter
|
|
|
96 |
db = get_existing_db(None, persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
|
97 |
hf_embedding_model, verbose=False)
|
98 |
if db is None:
|
99 |
+
from chromadb.config import Settings
|
100 |
+
client_settings = Settings(anonymized_telemetry=False,
|
101 |
+
chroma_db_impl="duckdb+parquet",
|
102 |
+
persist_directory=persist_directory)
|
103 |
db = Chroma.from_documents(documents=sources,
|
104 |
embedding=embedding,
|
105 |
persist_directory=persist_directory,
|
106 |
collection_name=collection_name,
|
107 |
+
client_settings=client_settings)
|
108 |
db.persist()
|
109 |
clear_embedding(db)
|
110 |
save_embed(db, use_openai_embedding, hf_embedding_model)
|
|
|
309 |
sanitize_bot_response: bool = False
|
310 |
|
311 |
prompter: Any = None
|
312 |
+
context: Any = ''
|
313 |
+
iinput: Any = ''
|
314 |
client: Any = None
|
315 |
|
316 |
class Config:
|
|
|
354 |
stream_output = self.stream
|
355 |
gr_client = self.client
|
356 |
client_langchain_mode = 'Disabled'
|
357 |
+
client_add_chat_history_to_context = True
|
358 |
client_langchain_action = LangChainAction.QUERY.value
|
359 |
client_langchain_agents = []
|
360 |
top_k_docs = 1
|
361 |
chunk = True
|
362 |
chunk_size = 512
|
363 |
client_kwargs = dict(instruction=prompt if self.chat_client else '', # only for chat=True
|
364 |
+
iinput=self.iinput if self.chat_client else '', # only for chat=True
|
365 |
+
context=self.context,
|
366 |
# streaming output is supported, loops over and outputs each generation in streaming mode
|
367 |
# but leave stream_output=False for simple input/output mode
|
368 |
stream_output=stream_output,
|
|
|
383 |
chat=self.chat_client,
|
384 |
|
385 |
instruction_nochat=prompt if not self.chat_client else '',
|
386 |
+
iinput_nochat=self.iinput if not self.chat_client else '',
|
387 |
langchain_mode=client_langchain_mode,
|
388 |
+
add_chat_history_to_context=client_add_chat_history_to_context,
|
389 |
langchain_action=client_langchain_action,
|
390 |
langchain_agents=client_langchain_agents,
|
391 |
top_k_docs=top_k_docs,
|
392 |
chunk=chunk,
|
393 |
chunk_size=chunk_size,
|
394 |
+
document_subset=DocumentSubset.Relevant.name,
|
395 |
+
document_choice=[DocumentChoice.ALL.value],
|
396 |
)
|
397 |
api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
|
398 |
if not stream_output:
|
|
|
462 |
stream: bool = False
|
463 |
sanitize_bot_response: bool = False
|
464 |
prompter: Any = None
|
465 |
+
context: Any = ''
|
466 |
+
iinput: Any = ''
|
467 |
tokenizer: Any = None
|
468 |
client: Any = None
|
469 |
|
|
|
505 |
prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer)
|
506 |
|
507 |
# NOTE: TGI server does not add prompting, so must do here
|
508 |
+
data_point = dict(context=self.context, instruction=prompt, input=self.iinput)
|
509 |
prompt = self.prompter.generate_prompt(data_point)
|
510 |
|
511 |
gen_server_kwargs = dict(do_sample=self.do_sample,
|
|
|
584 |
stop_sequences: Any = None
|
585 |
sanitize_bot_response: bool = False
|
586 |
prompter: Any = None
|
587 |
+
context: Any = ''
|
588 |
+
iinput: Any = ''
|
589 |
tokenizer: Any = None
|
590 |
|
591 |
@classmethod
|
|
|
611 |
for prompti, prompt in enumerate(prompts):
|
612 |
prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer)
|
613 |
# NOTE: OpenAI/vLLM server does not add prompting, so must do here
|
614 |
+
data_point = dict(context=self.context, instruction=prompt, input=self.iinput)
|
615 |
prompt = self.prompter.generate_prompt(data_point)
|
616 |
prompts[prompti] = prompt
|
617 |
|
|
|
689 |
prompt_type=None,
|
690 |
prompt_dict=None,
|
691 |
prompter=None,
|
692 |
+
context=None,
|
693 |
+
iinput=None,
|
694 |
sanitize_bot_response=False,
|
695 |
verbose=False,
|
696 |
):
|
697 |
+
if inference_server is None:
|
698 |
+
inference_server = ''
|
699 |
if use_openai_model or inference_server.startswith('openai') or inference_server.startswith('vllm'):
|
700 |
if use_openai_model and model_name is None:
|
701 |
model_name = "gpt-3.5-turbo"
|
702 |
+
# FIXME: Will later import be ignored? I think so, so should be fine
|
703 |
+
openai, inf_type = set_openai(inference_server)
|
704 |
kwargs_extra = {}
|
705 |
if inference_server == 'openai_chat' or inf_type == 'vllm_chat':
|
706 |
cls = H2OChatOpenAI
|
707 |
+
# FIXME: Support context, iinput
|
708 |
else:
|
709 |
cls = H2OOpenAI
|
710 |
if inf_type == 'vllm':
|
|
|
714 |
kwargs_extra = dict(stop_sequences=stop_sequences,
|
715 |
sanitize_bot_response=sanitize_bot_response,
|
716 |
prompter=prompter,
|
717 |
+
context=context,
|
718 |
+
iinput=iinput,
|
719 |
tokenizer=tokenizer,
|
720 |
client=None)
|
721 |
|
|
|
730 |
callbacks=callbacks if stream_output else None,
|
731 |
openai_api_key=openai.api_key,
|
732 |
openai_api_base=openai.api_base,
|
733 |
+
logit_bias=None if inf_type == 'vllm' else {},
|
734 |
max_retries=2,
|
735 |
streaming=stream_output,
|
736 |
**kwargs_extra
|
|
|
788 |
callbacks=callbacks if stream_output else None,
|
789 |
stream=stream_output,
|
790 |
prompter=prompter,
|
791 |
+
context=context,
|
792 |
+
iinput=iinput,
|
793 |
client=gr_client,
|
794 |
sanitize_bot_response=sanitize_bot_response,
|
795 |
)
|
|
|
810 |
callbacks=callbacks if stream_output else None,
|
811 |
stream=stream_output,
|
812 |
prompter=prompter,
|
813 |
+
context=context,
|
814 |
+
iinput=iinput,
|
815 |
tokenizer=tokenizer,
|
816 |
client=hf_client,
|
817 |
timeout=max_time,
|
|
|
844 |
verbose=verbose,
|
845 |
streaming=stream_output,
|
846 |
prompter=prompter,
|
847 |
+
context=context,
|
848 |
+
iinput=iinput,
|
849 |
)
|
850 |
else:
|
851 |
if model is None:
|
|
|
888 |
from h2oai_pipeline import H2OTextGenerationPipeline
|
889 |
pipe = H2OTextGenerationPipeline(model=model, use_prompter=True,
|
890 |
prompter=prompter,
|
891 |
+
context=context,
|
892 |
+
iinpout=iinput,
|
893 |
prompt_type=prompt_type,
|
894 |
prompt_dict=prompt_dict,
|
895 |
sanitize_bot_response=sanitize_bot_response,
|
|
|
1075 |
is_url=False, is_txt=False,
|
1076 |
enable_captions=True,
|
1077 |
captions_model=None,
|
1078 |
+
enable_ocr=False, enable_pdf_ocr='auto', caption_loader=None,
|
1079 |
headsize=50):
|
1080 |
if file is None:
|
1081 |
if fail_any_exception:
|
|
|
1092 |
base_name = sanitize_filename(base_name) + "_" + str(uuid.uuid4())[:10]
|
1093 |
base_path = os.path.join(dir_name, base_name)
|
1094 |
if is_url:
|
1095 |
+
file = file.strip() # in case accidental spaces in front or at end
|
1096 |
if file.lower().startswith('arxiv:'):
|
1097 |
query = file.lower().split('arxiv:')
|
1098 |
if len(query) == 2 and have_arxiv:
|
|
|
1244 |
from dotenv import dotenv_values
|
1245 |
env_kwargs = dotenv_values(env_gpt4all_file)
|
1246 |
pdf_class_name = env_kwargs.get('PDF_CLASS_NAME', 'PyMuPDFParser')
|
1247 |
+
doc1 = []
|
1248 |
+
handled = False
|
1249 |
if have_pymupdf and pdf_class_name == 'PyMuPDFParser':
|
1250 |
# GPL, only use if installed
|
1251 |
from langchain.document_loaders import PyMuPDFLoader
|
1252 |
# load() still chunks by pages, but every page has title at start to help
|
1253 |
doc1 = PyMuPDFLoader(file).load()
|
1254 |
+
# remove empty documents
|
1255 |
+
handled |= len(doc1) > 0
|
1256 |
+
doc1 = [x for x in doc1 if x.page_content]
|
1257 |
doc1 = clean_doc(doc1)
|
1258 |
+
if len(doc1) == 0:
|
1259 |
doc1 = UnstructuredPDFLoader(file).load()
|
1260 |
+
handled |= len(doc1) > 0
|
1261 |
+
# remove empty documents
|
1262 |
+
doc1 = [x for x in doc1 if x.page_content]
|
1263 |
# seems to not need cleaning in most cases
|
1264 |
+
if len(doc1) == 0:
|
1265 |
# open-source fallback
|
1266 |
# load() still chunks by pages, but every page has title at start to help
|
1267 |
doc1 = PyPDFLoader(file).load()
|
1268 |
+
handled |= len(doc1) > 0
|
1269 |
+
# remove empty documents
|
1270 |
+
doc1 = [x for x in doc1 if x.page_content]
|
1271 |
doc1 = clean_doc(doc1)
|
1272 |
+
if have_pymupdf and len(doc1) == 0:
|
1273 |
+
# GPL, only use if installed
|
1274 |
+
from langchain.document_loaders import PyMuPDFLoader
|
1275 |
+
# load() still chunks by pages, but every page has title at start to help
|
1276 |
+
doc1 = PyMuPDFLoader(file).load()
|
1277 |
+
handled |= len(doc1) > 0
|
1278 |
+
# remove empty documents
|
1279 |
+
doc1 = [x for x in doc1 if x.page_content]
|
1280 |
+
doc1 = clean_doc(doc1)
|
1281 |
+
if len(doc1) == 0 and enable_pdf_ocr == 'auto' or enable_pdf_ocr == 'on':
|
1282 |
+
# try OCR at the end since it is slowest, but it works well on pure image pages
|
1283 |
+
doc1 = UnstructuredPDFLoader(file, strategy='ocr_only').load()
|
1284 |
+
handled |= len(doc1) > 0
|
1285 |
+
# remove empty documents
|
1286 |
+
doc1 = [x for x in doc1 if x.page_content]
|
1287 |
+
# seems to not need cleaning in most cases
|
1288 |
# Some PDFs return nothing or junk from PDFMinerLoader
|
1289 |
+
if len(doc1) == 0:
|
1290 |
+
# if literally nothing, show failed to parse so user knows, since unlikely nothing in PDF at all.
|
1291 |
+
if handled:
|
1292 |
+
raise ValueError("%s had no valid text, but meta data was parsed" % file)
|
1293 |
+
else:
|
1294 |
+
raise ValueError("%s had no valid text and no meta data was parsed" % file)
|
1295 |
doc1 = chunk_sources(doc1, chunk=chunk, chunk_size=chunk_size)
|
1296 |
add_meta(doc1, file)
|
1297 |
elif file.lower().endswith('.csv'):
|
|
|
1344 |
is_url=False, is_txt=False,
|
1345 |
enable_captions=True,
|
1346 |
captions_model=None,
|
1347 |
+
enable_ocr=False, enable_pdf_ocr='auto', caption_loader=None):
|
1348 |
if verbose:
|
1349 |
if is_url:
|
1350 |
print("Ingesting URL: %s" % file, flush=True)
|
|
|
1362 |
enable_captions=enable_captions,
|
1363 |
captions_model=captions_model,
|
1364 |
enable_ocr=enable_ocr,
|
1365 |
+
enable_pdf_ocr=enable_pdf_ocr,
|
1366 |
caption_loader=caption_loader)
|
1367 |
except BaseException as e:
|
1368 |
print("Failed to ingest %s due to %s" % (file, traceback.format_exc()))
|
|
|
1371 |
else:
|
1372 |
exception_doc = Document(
|
1373 |
page_content='',
|
1374 |
+
metadata={"source": file, "exception": '%s Exception: %s' % (file, str(e)),
|
1375 |
"traceback": traceback.format_exc()})
|
1376 |
res = [exception_doc]
|
1377 |
if return_file:
|
|
|
1392 |
captions_model=None,
|
1393 |
caption_loader=None,
|
1394 |
enable_ocr=False,
|
1395 |
+
enable_pdf_ocr='auto',
|
1396 |
existing_files=[],
|
1397 |
existing_hash_ids={},
|
1398 |
):
|
|
|
1414 |
[globs_non_image_types.extend(glob.glob(os.path.join(path, "./**/*.%s" % ftype), recursive=True))
|
1415 |
for ftype in non_image_types]
|
1416 |
else:
|
1417 |
+
if isinstance(path_or_paths, str):
|
1418 |
+
if os.path.isfile(path_or_paths) or os.path.isdir(path_or_paths):
|
1419 |
+
path_or_paths = [path_or_paths]
|
1420 |
+
else:
|
1421 |
+
# path was deleted etc.
|
1422 |
+
return []
|
1423 |
# list/tuple of files (consume what can, and exception those that selected but cannot consume so user knows)
|
1424 |
+
assert isinstance(path_or_paths, (list, tuple, types.GeneratorType)), \
|
1425 |
+
"Wrong type for path_or_paths: %s %s" % (path_or_paths, type(path_or_paths))
|
1426 |
# reform out of allowed types
|
1427 |
globs_image_types.extend(flatten_list([[x for x in path_or_paths if x.endswith(y)] for y in image_types]))
|
1428 |
# could do below:
|
|
|
1474 |
captions_model=captions_model,
|
1475 |
caption_loader=caption_loader,
|
1476 |
enable_ocr=enable_ocr,
|
1477 |
+
enable_pdf_ocr=enable_pdf_ocr,
|
1478 |
)
|
1479 |
|
1480 |
if n_jobs != 1 and len(globs_non_image_types) > 1:
|
|
|
1507 |
with open(fil, 'rb') as f:
|
1508 |
documents.extend(pickle.load(f))
|
1509 |
# remove temp pickle
|
1510 |
+
remove(fil)
|
1511 |
else:
|
1512 |
documents = reduce(concat, documents)
|
1513 |
return documents
|
|
|
1515 |
|
1516 |
def prep_langchain(persist_directory,
|
1517 |
load_db_if_exists,
|
1518 |
+
db_type, use_openai_embedding, langchain_mode, langchain_mode_paths,
|
1519 |
hf_embedding_model, n_jobs=-1, kwargs_make_db={}):
|
1520 |
"""
|
1521 |
do prep first time, involving downloads
|
|
|
1525 |
assert langchain_mode not in ['MyData'], "Should not prep scratch data"
|
1526 |
|
1527 |
db_dir_exists = os.path.isdir(persist_directory)
|
1528 |
+
user_path = langchain_mode_paths.get(langchain_mode)
|
1529 |
|
1530 |
if db_dir_exists and user_path is None:
|
1531 |
print("Prep: persist_directory=%s exists, using" % persist_directory, flush=True)
|
|
|
1661 |
langchain_kwargs[k] = defaults_db[k]
|
1662 |
# final check for missing
|
1663 |
missing_kwargs = [x for x in func_names if x not in langchain_kwargs]
|
1664 |
+
assert not missing_kwargs, "Missing kwargs for make_db: %s" % missing_kwargs
|
1665 |
# only keep actual used
|
1666 |
langchain_kwargs = {k: v for k, v in langchain_kwargs.items() if k in func_names}
|
1667 |
return _make_db(**langchain_kwargs)
|
|
|
1695 |
first_para=False, text_limit=None,
|
1696 |
chunk=True, chunk_size=512,
|
1697 |
langchain_mode=None,
|
1698 |
+
langchain_mode_paths=None,
|
1699 |
db_type='faiss',
|
1700 |
load_db_if_exists=True,
|
1701 |
db=None,
|
1702 |
n_jobs=-1,
|
1703 |
verbose=False):
|
1704 |
persist_directory = get_persist_directory(langchain_mode)
|
1705 |
+
user_path = langchain_mode_paths.get(langchain_mode)
|
1706 |
# see if can get persistent chroma db
|
1707 |
db_trial = get_existing_db(db, persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
|
1708 |
hf_embedding_model, verbose=verbose)
|
|
|
1710 |
db = db_trial
|
1711 |
|
1712 |
sources = []
|
1713 |
+
if not db:
|
1714 |
+
if langchain_mode in ['wiki_full']:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1715 |
from read_wiki_full import get_all_documents
|
1716 |
small_test = None
|
1717 |
print("Generating new wiki", flush=True)
|
|
|
1721 |
sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
|
1722 |
print("Chunked new wiki", flush=True)
|
1723 |
sources.extend(sources1)
|
1724 |
+
elif langchain_mode in ['wiki']:
|
1725 |
sources1 = get_wiki_sources(first_para=first_para, text_limit=text_limit)
|
1726 |
if chunk:
|
1727 |
sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
|
1728 |
sources.extend(sources1)
|
1729 |
+
elif langchain_mode in ['github h2oGPT']:
|
1730 |
# sources = get_github_docs("dagster-io", "dagster")
|
1731 |
sources1 = get_github_docs("h2oai", "h2ogpt")
|
1732 |
# FIXME: always chunk for now
|
1733 |
sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
|
1734 |
sources.extend(sources1)
|
1735 |
+
elif langchain_mode in ['DriverlessAI docs']:
|
1736 |
sources1 = get_dai_docs(from_hf=True)
|
1737 |
if chunk and False: # FIXME: DAI docs are already chunked well, should only chunk more if over limit
|
1738 |
sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
|
1739 |
sources.extend(sources1)
|
1740 |
+
if user_path:
|
1741 |
+
# UserData or custom, which has to be from user's disk
|
1742 |
+
if db is not None:
|
1743 |
+
# NOTE: Ignore file names for now, only go by hash ids
|
1744 |
+
# existing_files = get_existing_files(db)
|
1745 |
+
existing_files = []
|
1746 |
+
existing_hash_ids = get_existing_hash_ids(db)
|
1747 |
+
else:
|
1748 |
+
# pretend no existing files so won't filter
|
1749 |
+
existing_files = []
|
1750 |
+
existing_hash_ids = []
|
1751 |
+
# chunk internally for speed over multiple docs
|
1752 |
+
# FIXME: If first had old Hash=None and switch embeddings,
|
1753 |
+
# then re-embed, and then hit here and reload so have hash, and then re-embed.
|
1754 |
+
sources1 = path_to_docs(user_path, n_jobs=n_jobs, chunk=chunk, chunk_size=chunk_size,
|
1755 |
+
existing_files=existing_files, existing_hash_ids=existing_hash_ids)
|
1756 |
+
new_metadata_sources = set([x.metadata['source'] for x in sources1])
|
1757 |
+
if new_metadata_sources:
|
1758 |
+
print("Loaded %s new files as sources to add to %s" % (len(new_metadata_sources), langchain_mode),
|
1759 |
+
flush=True)
|
1760 |
+
if verbose:
|
1761 |
+
print("Files added: %s" % '\n'.join(new_metadata_sources), flush=True)
|
1762 |
+
sources.extend(sources1)
|
1763 |
+
print("Loaded %s sources for potentially adding to %s" % (len(sources), langchain_mode), flush=True)
|
1764 |
+
|
1765 |
+
# see if got sources
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1766 |
if not sources:
|
1767 |
if verbose:
|
1768 |
if db is not None:
|
|
|
1785 |
else:
|
1786 |
print("Did not generate db since no sources", flush=True)
|
1787 |
new_sources_metadata = [x.metadata for x in sources]
|
1788 |
+
elif user_path is not None:
|
1789 |
print("Existing db, potentially adding %s sources from user_path=%s" % (len(sources), user_path), flush=True)
|
1790 |
db, num_new_sources, new_sources_metadata = add_to_db(db, sources, db_type=db_type,
|
1791 |
use_openai_embedding=use_openai_embedding,
|
|
|
1883 |
kwargs['answer_with_sources'] = True
|
1884 |
kwargs['show_rank'] = False
|
1885 |
missing_kwargs = [x for x in func_names if x not in kwargs]
|
1886 |
+
assert not missing_kwargs, "Missing kwargs for run_qa_db: %s" % missing_kwargs
|
1887 |
# only keep actual used
|
1888 |
kwargs = {k: v for k, v in kwargs.items() if k in func_names}
|
1889 |
try:
|
|
|
1897 |
context=None,
|
1898 |
use_openai_model=False, use_openai_embedding=False,
|
1899 |
first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
|
1900 |
+
langchain_mode_paths={},
|
1901 |
detect_user_path_changes_every_query=False,
|
1902 |
db_type='faiss',
|
1903 |
model_name=None, model=None, tokenizer=None, inference_server=None,
|
|
|
1907 |
prompt_type=None,
|
1908 |
prompt_dict=None,
|
1909 |
answer_with_sources=True,
|
1910 |
+
cut_distance=1.64,
|
1911 |
+
add_chat_history_to_context=True,
|
1912 |
sanitize_bot_response=False,
|
1913 |
show_rank=False,
|
1914 |
use_llm_if_no_docs=False,
|
|
|
1928 |
langchain_mode=None,
|
1929 |
langchain_action=None,
|
1930 |
langchain_agents=None,
|
1931 |
+
document_subset=DocumentSubset.Relevant.name,
|
1932 |
+
document_choice=[DocumentChoice.ALL.value],
|
1933 |
n_jobs=-1,
|
1934 |
verbose=False,
|
1935 |
cli=False,
|
|
|
1948 |
:param top_k_docs:
|
1949 |
:param chunk:
|
1950 |
:param chunk_size:
|
1951 |
+
:param langchain_mode_paths: dict of langchain_mode -> user path to glob recursively from
|
1952 |
:param db_type: 'faiss' for in-memory db or 'chroma' or 'weaviate' for persistent db
|
1953 |
:param model_name: model name, used to switch behaviors
|
1954 |
:param model: pre-initialized model, else will make new one
|
|
|
1956 |
:param answer_with_sources
|
1957 |
:return:
|
1958 |
"""
|
1959 |
+
assert langchain_mode_paths is not None
|
1960 |
if model is not None:
|
1961 |
assert model_name is not None # require so can make decisions
|
1962 |
assert query is not None
|
|
|
1971 |
else:
|
1972 |
prompt_dict = ''
|
1973 |
assert len(set(gen_hyper).difference(inspect.signature(get_llm).parameters)) == 0
|
1974 |
+
# pass in context to LLM directly, since already has prompt_type structure
|
1975 |
+
# can't pass through langchain in get_chain() to LLM: https://github.com/hwchase17/langchain/issues/6638
|
1976 |
llm, model_name, streamer, prompt_type_out = get_llm(use_openai_model=use_openai_model, model_name=model_name,
|
1977 |
model=model,
|
1978 |
tokenizer=tokenizer,
|
|
|
1992 |
prompt_type=prompt_type,
|
1993 |
prompt_dict=prompt_dict,
|
1994 |
prompter=prompter,
|
1995 |
+
context=context if add_chat_history_to_context else '',
|
1996 |
+
iinput=iinput if add_chat_history_to_context else '',
|
1997 |
sanitize_bot_response=sanitize_bot_response,
|
1998 |
verbose=verbose,
|
1999 |
)
|
2000 |
|
2001 |
+
use_docs_planned = False
|
2002 |
scores = []
|
2003 |
chain = None
|
2004 |
|
|
|
2010 |
sim_kwargs = {k: v for k, v in locals().items() if k in func_names}
|
2011 |
missing_kwargs = [x for x in func_names if x not in sim_kwargs]
|
2012 |
assert not missing_kwargs, "Missing: %s" % missing_kwargs
|
2013 |
+
docs, chain, scores, use_docs_planned, have_any_docs = get_chain(**sim_kwargs)
|
2014 |
if document_subset in non_query_commands:
|
2015 |
formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
|
2016 |
+
if not formatted_doc_chunks and not use_llm_if_no_docs:
|
2017 |
+
yield "No sources", ''
|
2018 |
+
return
|
2019 |
+
# if no sources, outside gpt_langchain, LLM will be used with '' input
|
2020 |
yield formatted_doc_chunks, ''
|
2021 |
return
|
2022 |
if not use_llm_if_no_docs:
|
|
|
2028 |
yield ret, extra
|
2029 |
return
|
2030 |
if not docs and langchain_mode not in [LangChainMode.DISABLED.value,
|
|
|
2031 |
LangChainMode.LLM.value]:
|
2032 |
ret = 'No relevant documents to query.' if have_any_docs else 'No documents to query.'
|
2033 |
extra = ''
|
|
|
2083 |
else:
|
2084 |
answer = chain()
|
2085 |
|
2086 |
+
if not use_docs_planned:
|
2087 |
ret = answer['output_text']
|
2088 |
extra = ''
|
2089 |
yield ret, extra
|
|
|
2095 |
|
2096 |
def get_chain(query=None,
|
2097 |
iinput=None,
|
2098 |
+
context=None, # FIXME: https://github.com/hwchase17/langchain/issues/6638
|
2099 |
use_openai_model=False, use_openai_embedding=False,
|
2100 |
first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
|
2101 |
+
langchain_mode_paths=None,
|
2102 |
detect_user_path_changes_every_query=False,
|
2103 |
db_type='faiss',
|
2104 |
model_name=None,
|
|
|
2106 |
hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
|
2107 |
prompt_type=None,
|
2108 |
prompt_dict=None,
|
2109 |
+
cut_distance=1.1,
|
2110 |
+
add_chat_history_to_context=True, # FIXME: https://github.com/hwchase17/langchain/issues/6638
|
2111 |
load_db_if_exists=False,
|
2112 |
db=None,
|
2113 |
langchain_mode=None,
|
2114 |
langchain_action=None,
|
2115 |
langchain_agents=None,
|
2116 |
+
document_subset=DocumentSubset.Relevant.name,
|
2117 |
+
document_choice=[DocumentChoice.ALL.value],
|
2118 |
n_jobs=-1,
|
2119 |
# beyond run_db_query:
|
2120 |
llm=None,
|
|
|
2129 |
assert langchain_agents is not None # should be at least []
|
2130 |
# determine whether use of context out of docs is planned
|
2131 |
if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
|
2132 |
+
if langchain_mode in ['Disabled', 'LLM']:
|
2133 |
+
use_docs_planned = False
|
2134 |
else:
|
2135 |
+
use_docs_planned = True
|
2136 |
else:
|
2137 |
+
use_docs_planned = True
|
2138 |
|
2139 |
# https://github.com/hwchase17/langchain/issues/1946
|
2140 |
# FIXME: Seems no way to get size of chroma db to limit top_k_docs to avoid
|
|
|
2151 |
# avoid looking at user_path during similarity search db handling,
|
2152 |
# if already have db and not updating from user_path every query
|
2153 |
# but if db is None, no db yet loaded (e.g. from prep), so allow user_path to be whatever it was
|
2154 |
+
if langchain_mode_paths is None:
|
2155 |
+
langchain_mode_paths = {}
|
2156 |
+
langchain_mode_paths = langchain_mode_paths.copy()
|
2157 |
+
langchain_mode_paths[langchain_mode] = None
|
2158 |
db, num_new_sources, new_sources_metadata = make_db(use_openai_embedding=use_openai_embedding,
|
2159 |
hf_embedding_model=hf_embedding_model,
|
2160 |
first_para=first_para, text_limit=text_limit,
|
2161 |
chunk=chunk,
|
2162 |
chunk_size=chunk_size,
|
2163 |
langchain_mode=langchain_mode,
|
2164 |
+
langchain_mode_paths=langchain_mode_paths,
|
2165 |
db_type=db_type,
|
2166 |
load_db_if_exists=load_db_if_exists,
|
2167 |
db=db,
|
|
|
2181 |
else:
|
2182 |
extra = ""
|
2183 |
prefix = ""
|
2184 |
+
if langchain_mode in ['Disabled', 'LLM'] or not use_docs_planned:
|
2185 |
template_if_no_docs = template = """%s{context}{question}""" % prefix
|
2186 |
else:
|
2187 |
template = """%s
|
|
|
2222 |
else:
|
2223 |
use_template = False
|
2224 |
|
2225 |
+
if db and use_docs_planned:
|
2226 |
base_path = 'locks'
|
2227 |
makedirs(base_path)
|
2228 |
if hasattr(db, '_persist_directory'):
|
|
|
2236 |
filter_kwargs = {}
|
2237 |
else:
|
2238 |
assert document_choice is not None, "Document choice was None"
|
2239 |
+
if len(document_choice) >= 1 and document_choice[0] == DocumentChoice.ALL.value:
|
2240 |
filter_kwargs = {}
|
2241 |
elif len(document_choice) >= 2:
|
2242 |
+
if document_choice[0] == DocumentChoice.ALL.value:
|
2243 |
# remove 'All'
|
2244 |
document_choice = document_choice[1:]
|
2245 |
or_filter = [{"source": {"$eq": x}} for x in document_choice]
|
|
|
2251 |
else:
|
2252 |
# shouldn't reach
|
2253 |
filter_kwargs = {}
|
2254 |
+
if langchain_mode in [LangChainMode.LLM.value]:
|
2255 |
docs = []
|
2256 |
scores = []
|
2257 |
+
elif document_subset == DocumentSubset.TopKSources.name or query in [None, '', '\n']:
|
2258 |
db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs)
|
2259 |
# similar to langchain's chroma's _results_to_docs_and_scores
|
2260 |
docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0)
|
|
|
2342 |
docs_with_score.reverse()
|
2343 |
# cut off so no high distance docs/sources considered
|
2344 |
have_any_docs |= len(docs_with_score) > 0 # before cut
|
2345 |
+
docs = [x[0] for x in docs_with_score if x[1] < cut_distance]
|
2346 |
+
scores = [x[1] for x in docs_with_score if x[1] < cut_distance]
|
2347 |
if len(scores) > 0 and verbose:
|
2348 |
print("Distance: min: %s max: %s mean: %s median: %s" %
|
2349 |
(scores[0], scores[-1], np.mean(scores), np.median(scores)), flush=True)
|
|
|
2351 |
docs = []
|
2352 |
scores = []
|
2353 |
|
2354 |
+
if not docs and use_docs_planned and model_name not in non_hf_types:
|
2355 |
# if HF type and have no docs, can bail out
|
2356 |
return docs, None, [], False, have_any_docs
|
2357 |
|
|
|
2374 |
|
2375 |
if len(docs) == 0:
|
2376 |
# avoid context == in prompt then
|
2377 |
+
use_docs_planned = False
|
2378 |
template = template_if_no_docs
|
2379 |
|
2380 |
if langchain_action == LangChainAction.QUERY.value:
|
|
|
2390 |
else:
|
2391 |
# only if use_openai_model = True, unused normally except in testing
|
2392 |
chain = load_qa_with_sources_chain(llm)
|
2393 |
+
if not use_docs_planned:
|
2394 |
chain_kwargs = dict(input_documents=[], question=query)
|
2395 |
else:
|
2396 |
chain_kwargs = dict(input_documents=docs, question=query)
|
|
|
2417 |
else:
|
2418 |
raise RuntimeError("No such langchain_action=%s" % langchain_action)
|
2419 |
|
2420 |
+
return docs, target, scores, use_docs_planned, have_any_docs
|
2421 |
|
2422 |
|
2423 |
def get_sources_answer(query, answer, scores, show_rank, answer_with_sources, verbose=False):
|
gradio_runner.py
CHANGED
@@ -50,16 +50,20 @@ def fix_pydantic_duplicate_validators_error():
|
|
50 |
|
51 |
fix_pydantic_duplicate_validators_error()
|
52 |
|
53 |
-
from enums import
|
|
|
54 |
from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, get_dark_js, spacing_xsm, radius_xsm, \
|
55 |
text_xsm
|
56 |
from prompter import prompt_type_to_model_name, prompt_types_strings, inv_prompt_type_to_model_lower, non_hf_types, \
|
57 |
get_prompt
|
58 |
-
from utils import
|
59 |
-
ping, get_short_name, makedirs, get_kwargs, remove, system_info, ping_gpu, get_url, get_local_ip
|
60 |
-
|
61 |
-
|
62 |
-
|
|
|
|
|
|
|
63 |
|
64 |
from apscheduler.schedulers.background import BackgroundScheduler
|
65 |
|
@@ -94,12 +98,9 @@ def go_gradio(**kwargs):
|
|
94 |
memory_restriction_level = kwargs['memory_restriction_level']
|
95 |
n_gpus = kwargs['n_gpus']
|
96 |
admin_pass = kwargs['admin_pass']
|
97 |
-
model_state0 = kwargs['model_state0']
|
98 |
model_states = kwargs['model_states']
|
99 |
-
score_model_state0 = kwargs['score_model_state0']
|
100 |
dbs = kwargs['dbs']
|
101 |
db_type = kwargs['db_type']
|
102 |
-
visible_langchain_modes = kwargs['visible_langchain_modes']
|
103 |
visible_langchain_actions = kwargs['visible_langchain_actions']
|
104 |
visible_langchain_agents = kwargs['visible_langchain_agents']
|
105 |
allow_upload_to_user_data = kwargs['allow_upload_to_user_data']
|
@@ -112,8 +113,19 @@ def go_gradio(**kwargs):
|
|
112 |
enable_captions = kwargs['enable_captions']
|
113 |
captions_model = kwargs['captions_model']
|
114 |
enable_ocr = kwargs['enable_ocr']
|
|
|
115 |
caption_loader = kwargs['caption_loader']
|
116 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
# easy update of kwargs needed for evaluate() etc.
|
118 |
queue = True
|
119 |
allow_upload = allow_upload_to_user_data or allow_upload_to_my_data
|
@@ -133,25 +145,11 @@ def go_gradio(**kwargs):
|
|
133 |
" use Enter for multiple input lines)"
|
134 |
|
135 |
title = 'h2oGPT'
|
136 |
-
|
137 |
-
|
138 |
-
description = f"""Model {kwargs['base_model']} Instruct dataset.
|
139 |
-
For more information, visit our GitHub pages: [h2oGPT](https://github.com/h2oai/h2ogpt) and [H2O LLM Studio](https://github.com/h2oai/h2o-llmstudio).
|
140 |
-
Command: {str(' '.join(sys.argv))}
|
141 |
-
Hash: {get_githash()}
|
142 |
-
"""
|
143 |
-
else:
|
144 |
-
description = more_info
|
145 |
-
description_bottom = "If this host is busy, try [Multi-Model](https://gpt.h2o.ai), [Falcon 40B](http://falcon.h2o.ai), [HF Spaces1](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot) or [HF Spaces2](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot2)<br>"
|
146 |
if is_hf:
|
147 |
description_bottom += '''<a href="https://huggingface.co/spaces/h2oai/h2ogpt-chatbot?duplicate=true"><img src="https://bit.ly/3gLdBN6" style="white-space: nowrap" alt="Duplicate Space"></a>'''
|
148 |
-
|
149 |
-
if kwargs['verbose']:
|
150 |
-
task_info_md = f"""
|
151 |
-
### Task: {kwargs['task_info']}"""
|
152 |
-
else:
|
153 |
-
task_info_md = ''
|
154 |
-
|
155 |
css_code = get_css(kwargs)
|
156 |
|
157 |
if kwargs['gradio_offline_level'] >= 0:
|
@@ -181,9 +179,9 @@ def go_gradio(**kwargs):
|
|
181 |
demo = gr.Blocks(theme=theme, css=css_code, title="h2oGPT", analytics_enabled=False)
|
182 |
callback = gr.CSVLogger()
|
183 |
|
184 |
-
|
185 |
-
if kwargs['base_model'].strip() not in
|
186 |
-
|
187 |
lora_options = kwargs['extra_lora_options']
|
188 |
if kwargs['lora_weights'].strip() not in lora_options:
|
189 |
lora_options = [kwargs['lora_weights'].strip()] + lora_options
|
@@ -198,7 +196,7 @@ def go_gradio(**kwargs):
|
|
198 |
|
199 |
# always add in no lora case
|
200 |
# add fake space so doesn't go away in gradio dropdown
|
201 |
-
|
202 |
lora_options = [no_lora_str] + lora_options
|
203 |
server_options = [no_server_str] + server_options
|
204 |
# always add in no model case so can free memory
|
@@ -252,6 +250,14 @@ def go_gradio(**kwargs):
|
|
252 |
# else gets input_list at time of submit that is old, and shows up as truncated in chatbot
|
253 |
return x
|
254 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
255 |
with demo:
|
256 |
# avoid actual model/tokenizer here or anything that would be bad to deepcopy
|
257 |
# https://github.com/gradio-app/gradio/issues/3558
|
@@ -265,18 +271,32 @@ def go_gradio(**kwargs):
|
|
265 |
prompt_dict=kwargs['prompt_dict'],
|
266 |
)
|
267 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
268 |
model_state2 = gr.State(kwargs['model_state_none'].copy())
|
269 |
-
model_options_state = gr.State([
|
270 |
lora_options_state = gr.State([lora_options])
|
271 |
server_options_state = gr.State([server_options])
|
272 |
-
my_db_state = gr.State(
|
273 |
chat_state = gr.State({})
|
274 |
-
docs_state00 = kwargs['document_choice'] + [
|
275 |
docs_state0 = []
|
276 |
[docs_state0.append(x) for x in docs_state00 if x not in docs_state0]
|
277 |
docs_state = gr.State(docs_state0)
|
278 |
viewable_docs_state0 = []
|
279 |
viewable_docs_state = gr.State(viewable_docs_state0)
|
|
|
|
|
|
|
280 |
gr.Markdown(f"""
|
281 |
{get_h2o_title(title, description) if kwargs['h2ocolors'] else get_simple_title(title, description)}
|
282 |
""")
|
@@ -290,7 +310,7 @@ def go_gradio(**kwargs):
|
|
290 |
'model_lock'] else "Response Scores: %s" % nas
|
291 |
|
292 |
if kwargs['langchain_mode'] != LangChainMode.DISABLED.value:
|
293 |
-
extra_prompt_form = ". For summarization,
|
294 |
else:
|
295 |
extra_prompt_form = ""
|
296 |
if kwargs['input_lines'] > 1:
|
@@ -298,6 +318,34 @@ def go_gradio(**kwargs):
|
|
298 |
else:
|
299 |
instruction_label = "Enter to Submit, Shift-Enter for more lines%s" % extra_prompt_form
|
300 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
301 |
normal_block = gr.Row(visible=not base_wanted, equal_height=False)
|
302 |
with normal_block:
|
303 |
side_bar = gr.Column(elem_id="col_container", scale=1, min_width=100)
|
@@ -318,6 +366,7 @@ def go_gradio(**kwargs):
|
|
318 |
scale=1,
|
319 |
min_width=0,
|
320 |
elem_id="warning", elem_classes="feedback")
|
|
|
321 |
url_visible = kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_url_upload
|
322 |
url_label = 'URL/ArXiv' if have_arxiv else 'URL'
|
323 |
url_text = gr.Textbox(label=url_label,
|
@@ -331,29 +380,20 @@ def go_gradio(**kwargs):
|
|
331 |
visible=text_visible)
|
332 |
github_textbox = gr.Textbox(label="Github URL", visible=False) # FIXME WIP
|
333 |
database_visible = kwargs['langchain_mode'] != 'Disabled'
|
334 |
-
with gr.Accordion("
|
335 |
-
|
336 |
-
# don't show 'wiki' since only usually useful for internal testing at moment
|
337 |
-
no_show_modes = ['Disabled', 'wiki']
|
338 |
-
else:
|
339 |
-
no_show_modes = ['Disabled']
|
340 |
-
allowed_modes = visible_langchain_modes.copy()
|
341 |
-
allowed_modes = [x for x in allowed_modes if x in dbs]
|
342 |
-
allowed_modes += ['ChatLLM', 'LLM']
|
343 |
-
if allow_upload_to_my_data and 'MyData' not in allowed_modes:
|
344 |
-
allowed_modes += ['MyData']
|
345 |
-
if allow_upload_to_user_data and 'UserData' not in allowed_modes:
|
346 |
-
allowed_modes += ['UserData']
|
347 |
langchain_mode = gr.Radio(
|
348 |
-
|
349 |
value=kwargs['langchain_mode'],
|
350 |
label="Collections",
|
351 |
show_label=True,
|
352 |
visible=kwargs['langchain_mode'] != 'Disabled',
|
353 |
min_width=100)
|
354 |
-
|
|
|
|
|
355 |
label="Subset",
|
356 |
-
value=
|
357 |
interactive=True,
|
358 |
)
|
359 |
allowed_actions = [x for x in langchain_actions if x in visible_langchain_actions]
|
@@ -417,9 +457,9 @@ def go_gradio(**kwargs):
|
|
417 |
mw1 = 50
|
418 |
mw2 = 50
|
419 |
with gr.Column(min_width=mw1):
|
420 |
-
submit = gr.Button(value='Submit', variant='primary',
|
421 |
min_width=mw1)
|
422 |
-
stop_btn = gr.Button(value="Stop", variant='secondary',
|
423 |
min_width=mw1)
|
424 |
save_chat_btn = gr.Button("Save", size='sm', min_width=mw1)
|
425 |
with gr.Column(min_width=mw2):
|
@@ -440,20 +480,50 @@ def go_gradio(**kwargs):
|
|
440 |
with gr.TabItem("Document Selection"):
|
441 |
document_choice = gr.Dropdown(docs_state0,
|
442 |
label="Select Subset of Document(s) %s" % file_types_str,
|
443 |
-
value=
|
444 |
interactive=True,
|
445 |
multiselect=True,
|
446 |
visible=kwargs['langchain_mode'] != 'Disabled',
|
447 |
)
|
448 |
sources_visible = kwargs['langchain_mode'] != 'Disabled' and enable_sources_list
|
449 |
with gr.Row():
|
450 |
-
|
451 |
-
|
452 |
-
|
453 |
-
|
454 |
-
|
455 |
-
|
456 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
457 |
|
458 |
sources_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list,
|
459 |
equal_height=False)
|
@@ -723,19 +793,20 @@ def go_gradio(**kwargs):
|
|
723 |
side_bar_btn = gr.Button("Toggle SideBar", variant="secondary", size="sm")
|
724 |
submit_buttons_btn = gr.Button("Toggle Submit Buttons", variant="secondary", size="sm")
|
725 |
col_tabs_scale = gr.Slider(minimum=1, maximum=20, value=10, step=1, label='Window Size')
|
726 |
-
text_outputs_height = gr.Slider(minimum=100, maximum=
|
727 |
-
step=
|
728 |
dark_mode_btn = gr.Button("Dark Mode", variant="secondary", size="sm")
|
729 |
with gr.Column(scale=4):
|
730 |
pass
|
|
|
731 |
admin_row = gr.Row()
|
732 |
with admin_row:
|
733 |
with gr.Column(scale=1):
|
734 |
-
admin_pass_textbox = gr.Textbox(label="Admin Password", type='password',
|
735 |
-
|
736 |
with gr.Column(scale=4):
|
737 |
pass
|
738 |
-
system_row = gr.Row(visible=
|
739 |
with system_row:
|
740 |
with gr.Column():
|
741 |
with gr.Row():
|
@@ -799,23 +870,24 @@ def go_gradio(**kwargs):
|
|
799 |
else:
|
800 |
return tuple([gr.update(interactive=True)] * len(args))
|
801 |
|
802 |
-
# Add to UserData
|
803 |
update_db_func = functools.partial(update_user_db,
|
804 |
dbs=dbs,
|
805 |
db_type=db_type,
|
806 |
use_openai_embedding=use_openai_embedding,
|
807 |
hf_embedding_model=hf_embedding_model,
|
808 |
-
enable_captions=enable_captions,
|
809 |
captions_model=captions_model,
|
810 |
-
|
811 |
caption_loader=caption_loader,
|
|
|
|
|
812 |
verbose=kwargs['verbose'],
|
813 |
-
user_path=kwargs['user_path'],
|
814 |
n_jobs=kwargs['n_jobs'],
|
815 |
)
|
816 |
add_file_outputs = [fileup_output, langchain_mode]
|
817 |
add_file_kwargs = dict(fn=update_db_func,
|
818 |
-
inputs=[fileup_output, my_db_state, chunk, chunk_size,
|
|
|
819 |
outputs=add_file_outputs + [sources_text, doc_exception_text],
|
820 |
queue=queue,
|
821 |
api_name='add_file' if allow_api and allow_upload_to_user_data else None)
|
@@ -827,6 +899,15 @@ def go_gradio(**kwargs):
|
|
827 |
eventdb1b = eventdb1.then(make_interactive, inputs=add_file_outputs, outputs=add_file_outputs,
|
828 |
show_progress='minimal')
|
829 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
830 |
# note for update_user_db_func output is ignored for db
|
831 |
|
832 |
def clear_textbox():
|
@@ -836,7 +917,8 @@ def go_gradio(**kwargs):
|
|
836 |
|
837 |
add_url_outputs = [url_text, langchain_mode]
|
838 |
add_url_kwargs = dict(fn=update_user_db_url_func,
|
839 |
-
inputs=[url_text, my_db_state, chunk, chunk_size,
|
|
|
840 |
outputs=add_url_outputs + [sources_text, doc_exception_text],
|
841 |
queue=queue,
|
842 |
api_name='add_url' if allow_api and allow_upload_to_user_data else None)
|
@@ -853,7 +935,8 @@ def go_gradio(**kwargs):
|
|
853 |
update_user_db_txt_func = functools.partial(update_db_func, is_txt=True)
|
854 |
add_text_outputs = [user_text_text, langchain_mode]
|
855 |
add_text_kwargs = dict(fn=update_user_db_txt_func,
|
856 |
-
inputs=[user_text_text, my_db_state, chunk, chunk_size,
|
|
|
857 |
outputs=add_text_outputs + [sources_text, doc_exception_text],
|
858 |
queue=queue,
|
859 |
api_name='add_text' if allow_api and allow_upload_to_user_data else None
|
@@ -865,7 +948,7 @@ def go_gradio(**kwargs):
|
|
865 |
eventdb3 = eventdb3b.then(**add_text_kwargs, show_progress='full')
|
866 |
eventdb3c = eventdb3.then(make_interactive, inputs=add_text_outputs, outputs=add_text_outputs,
|
867 |
show_progress='minimal')
|
868 |
-
db_events = [eventdb1a, eventdb1, eventdb1b,
|
869 |
eventdb2a, eventdb2, eventdb2b, eventdb2c,
|
870 |
eventdb3a, eventdb3b, eventdb3, eventdb3c]
|
871 |
|
@@ -873,14 +956,14 @@ def go_gradio(**kwargs):
|
|
873 |
|
874 |
# if change collection source, must clear doc selections from it to avoid inconsistency
|
875 |
def clear_doc_choice():
|
876 |
-
return gr.Dropdown.update(choices=docs_state0, value=
|
877 |
|
878 |
langchain_mode.change(clear_doc_choice, inputs=None, outputs=document_choice, queue=False)
|
879 |
|
880 |
def resize_col_tabs(x):
|
881 |
return gr.Dropdown.update(scale=x)
|
882 |
|
883 |
-
col_tabs_scale.change(fn=resize_col_tabs, inputs=col_tabs_scale, outputs=col_tabs)
|
884 |
|
885 |
def resize_chatbots(x, num_model_lock=0):
|
886 |
if num_model_lock == 0:
|
@@ -891,7 +974,7 @@ def go_gradio(**kwargs):
|
|
891 |
|
892 |
resize_chatbots_func = functools.partial(resize_chatbots, num_model_lock=len(text_outputs))
|
893 |
text_outputs_height.change(fn=resize_chatbots_func, inputs=text_outputs_height,
|
894 |
-
outputs=[text_output, text_output2] + text_outputs)
|
895 |
|
896 |
def update_dropdown(x):
|
897 |
return gr.Dropdown.update(choices=x, value=[docs_state0[0]])
|
@@ -982,7 +1065,8 @@ def go_gradio(**kwargs):
|
|
982 |
if file.startswith('http') or file.startswith('https'):
|
983 |
# if file is online, then might as well use google(?)
|
984 |
document1 = file
|
985 |
-
return gr.update(visible=True,
|
|
|
986 |
</iframe>
|
987 |
"""), dummy1, dummy1, dummy1
|
988 |
else:
|
@@ -1005,9 +1089,11 @@ def go_gradio(**kwargs):
|
|
1005 |
|
1006 |
refresh_sources1 = functools.partial(update_and_get_source_files_given_langchain_mode,
|
1007 |
**get_kwargs(update_and_get_source_files_given_langchain_mode,
|
1008 |
-
exclude_names=['
|
|
|
1009 |
**all_kwargs))
|
1010 |
-
eventdb9 = refresh_sources_btn.click(fn=refresh_sources1,
|
|
|
1011 |
outputs=sources_text,
|
1012 |
api_name='refresh_sources' if allow_api else None)
|
1013 |
|
@@ -1017,9 +1103,153 @@ def go_gradio(**kwargs):
|
|
1017 |
def close_admin(x):
|
1018 |
return gr.update(visible=not (x == admin_pass))
|
1019 |
|
1020 |
-
|
1021 |
.then(close_admin, inputs=admin_pass_textbox, outputs=admin_row, queue=False)
|
1022 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1023 |
inputs_list, inputs_dict = get_inputs_list(all_kwargs, kwargs['model_lower'], model_id=1)
|
1024 |
inputs_list2, inputs_dict2 = get_inputs_list(all_kwargs, kwargs['model_lower'], model_id=2)
|
1025 |
from functools import partial
|
@@ -1031,11 +1261,11 @@ def go_gradio(**kwargs):
|
|
1031 |
def evaluate_nochat(*args1, default_kwargs1=None, str_api=False, **kwargs1):
|
1032 |
args_list = list(args1)
|
1033 |
if str_api:
|
1034 |
-
user_kwargs = args_list[
|
1035 |
assert isinstance(user_kwargs, str)
|
1036 |
user_kwargs = ast.literal_eval(user_kwargs)
|
1037 |
else:
|
1038 |
-
user_kwargs = {k: v for k, v in zip(eval_func_param_names, args_list[
|
1039 |
# only used for submit_nochat_api
|
1040 |
user_kwargs['chat'] = False
|
1041 |
if 'stream_output' not in user_kwargs:
|
@@ -1054,10 +1284,11 @@ def go_gradio(**kwargs):
|
|
1054 |
# correct ordering. Note some things may not be in default_kwargs, so can't be default of user_kwargs.get()
|
1055 |
model_state1 = args_list[0]
|
1056 |
my_db_state1 = args_list[1]
|
|
|
1057 |
args_list = [user_kwargs[k] if k in user_kwargs and user_kwargs[k] is not None else default_kwargs1[k] for k
|
1058 |
in eval_func_param_names]
|
1059 |
assert len(args_list) == len(eval_func_param_names)
|
1060 |
-
args_list = [model_state1, my_db_state1] + args_list
|
1061 |
|
1062 |
try:
|
1063 |
for res_dict in evaluate(*tuple(args_list), **kwargs1):
|
@@ -1261,10 +1492,7 @@ def go_gradio(**kwargs):
|
|
1261 |
history[-1][1] = None
|
1262 |
return history
|
1263 |
if user_message1 in ['', None, '\n']:
|
1264 |
-
if
|
1265 |
-
DocumentChoices.All.name != document_subset1 \
|
1266 |
-
or \
|
1267 |
-
langchain_mode1 in [LangChainMode.CHAT_LLM.value, LangChainMode.LLM.value]:
|
1268 |
# reject non-retry submit/enter
|
1269 |
return history
|
1270 |
user_message1 = fix_text_for_gradio(user_message1)
|
@@ -1311,10 +1539,12 @@ def go_gradio(**kwargs):
|
|
1311 |
API only called for which_model=0, default for inputs_list, but rest should ignore inputs_list
|
1312 |
:return: last element is True if should run bot, False if should just yield history
|
1313 |
"""
|
|
|
1314 |
# don't deepcopy, can contain model itself
|
1315 |
args_list = list(args).copy()
|
1316 |
-
model_state1 = args_list[-
|
1317 |
-
my_db_state1 = args_list[-
|
|
|
1318 |
history = args_list[-1]
|
1319 |
prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
|
1320 |
prompt_dict1 = args_list[eval_func_param_names.index('prompt_dict')]
|
@@ -1322,8 +1552,9 @@ def go_gradio(**kwargs):
|
|
1322 |
if model_state1['model'] is None or model_state1['model'] == no_model_str:
|
1323 |
return history, None, None, None
|
1324 |
|
1325 |
-
args_list = args_list[:-
|
1326 |
langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
|
|
|
1327 |
langchain_action1 = args_list[eval_func_param_names.index('langchain_action')]
|
1328 |
langchain_agents1 = args_list[eval_func_param_names.index('langchain_agents')]
|
1329 |
document_subset1 = args_list[eval_func_param_names.index('document_subset')]
|
@@ -1338,10 +1569,7 @@ def go_gradio(**kwargs):
|
|
1338 |
instruction1 = history[-1][0]
|
1339 |
history[-1][1] = None
|
1340 |
elif not instruction1:
|
1341 |
-
if
|
1342 |
-
DocumentChoices.All.name != document_choice1 \
|
1343 |
-
or \
|
1344 |
-
langchain_mode1 in [LangChainMode.CHAT_LLM.value, LangChainMode.LLM.value]:
|
1345 |
# if not retrying, then reject empty query
|
1346 |
return history, None, None, None
|
1347 |
elif len(history) > 0 and history[-1][1] not in [None, '']:
|
@@ -1358,7 +1586,9 @@ def go_gradio(**kwargs):
|
|
1358 |
|
1359 |
chat1 = args_list[eval_func_param_names.index('chat')]
|
1360 |
model_max_length1 = get_model_max_length(model_state1)
|
1361 |
-
context1 = history_to_context(history, langchain_mode1,
|
|
|
|
|
1362 |
model_max_length1, memory_restriction_level,
|
1363 |
kwargs['keep_sources_in_context'])
|
1364 |
args_list[0] = instruction1 # override original instruction with history from user
|
@@ -1367,6 +1597,7 @@ def go_gradio(**kwargs):
|
|
1367 |
fun1 = partial(evaluate,
|
1368 |
model_state1,
|
1369 |
my_db_state1,
|
|
|
1370 |
*tuple(args_list),
|
1371 |
**kwargs_evaluate)
|
1372 |
|
@@ -1412,24 +1643,26 @@ def go_gradio(**kwargs):
|
|
1412 |
clear_torch_cache()
|
1413 |
return
|
1414 |
|
1415 |
-
def clear_embeddings(langchain_mode1,
|
1416 |
# clear any use of embedding that sits on GPU, else keeps accumulating GPU usage even if clear torch cache
|
1417 |
-
if db_type == 'chroma' and langchain_mode1 not in ['
|
1418 |
from gpt_langchain import clear_embedding
|
1419 |
db = dbs.get('langchain_mode1')
|
1420 |
if db is not None and not isinstance(db, str):
|
1421 |
clear_embedding(db)
|
1422 |
-
if
|
1423 |
-
|
|
|
|
|
1424 |
|
1425 |
def bot(*args, retry=False):
|
1426 |
-
history, fun1, langchain_mode1,
|
1427 |
try:
|
1428 |
for res in get_response(fun1, history):
|
1429 |
yield res
|
1430 |
finally:
|
1431 |
clear_torch_cache()
|
1432 |
-
clear_embeddings(langchain_mode1,
|
1433 |
|
1434 |
def all_bot(*args, retry=False, model_states1=None):
|
1435 |
args_list = list(args).copy()
|
@@ -1439,12 +1672,14 @@ def go_gradio(**kwargs):
|
|
1439 |
stream_output1 = args_list[eval_func_param_names.index('stream_output')]
|
1440 |
max_time1 = args_list[eval_func_param_names.index('max_time')]
|
1441 |
langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
|
1442 |
-
|
|
|
1443 |
try:
|
1444 |
gen_list = []
|
1445 |
for chatboti, (chatbot1, model_state1) in enumerate(zip(chatbots, model_states1)):
|
1446 |
args_list1 = args_list0.copy()
|
1447 |
-
args_list1.insert(-
|
|
|
1448 |
# if at start, have None in response still, replace with '' so client etc. acts like normal
|
1449 |
# assumes other parts of code treat '' and None as if no response yet from bot
|
1450 |
# can't do this later in bot code as racy with threaded generators
|
@@ -1454,8 +1689,8 @@ def go_gradio(**kwargs):
|
|
1454 |
# so consistent with prep_bot()
|
1455 |
# with model_state1 at -3, my_db_state1 at -2, and history(chatbot) at -1
|
1456 |
# langchain_mode1 and my_db_state1 should be same for every bot
|
1457 |
-
history, fun1, langchain_mode1,
|
1458 |
-
|
1459 |
gen1 = get_response(fun1, history)
|
1460 |
if stream_output1:
|
1461 |
gen1 = TimeoutIterator(gen1, timeout=0.01, sentinel=None, raise_on_exception=False)
|
@@ -1501,7 +1736,7 @@ def go_gradio(**kwargs):
|
|
1501 |
print("Generate exceptions: %s" % exceptions, flush=True)
|
1502 |
finally:
|
1503 |
clear_torch_cache()
|
1504 |
-
clear_embeddings(langchain_mode1,
|
1505 |
|
1506 |
# NORMAL MODEL
|
1507 |
user_args = dict(fn=functools.partial(user, sanitize_user_prompt=kwargs['sanitize_user_prompt']),
|
@@ -1509,11 +1744,11 @@ def go_gradio(**kwargs):
|
|
1509 |
outputs=text_output,
|
1510 |
)
|
1511 |
bot_args = dict(fn=bot,
|
1512 |
-
inputs=inputs_list + [model_state, my_db_state] + [text_output],
|
1513 |
outputs=[text_output, chat_exception_text],
|
1514 |
)
|
1515 |
retry_bot_args = dict(fn=functools.partial(bot, retry=True),
|
1516 |
-
inputs=inputs_list + [model_state, my_db_state] + [text_output],
|
1517 |
outputs=[text_output, chat_exception_text],
|
1518 |
)
|
1519 |
retry_user_args = dict(fn=functools.partial(user, retry=True),
|
@@ -1531,11 +1766,11 @@ def go_gradio(**kwargs):
|
|
1531 |
outputs=text_output2,
|
1532 |
)
|
1533 |
bot_args2 = dict(fn=bot,
|
1534 |
-
inputs=inputs_list2 + [model_state2, my_db_state] + [text_output2],
|
1535 |
outputs=[text_output2, chat_exception_text],
|
1536 |
)
|
1537 |
retry_bot_args2 = dict(fn=functools.partial(bot, retry=True),
|
1538 |
-
inputs=inputs_list2 + [model_state2, my_db_state] + [text_output2],
|
1539 |
outputs=[text_output2, chat_exception_text],
|
1540 |
)
|
1541 |
retry_user_args2 = dict(fn=functools.partial(user, retry=True),
|
@@ -1556,11 +1791,11 @@ def go_gradio(**kwargs):
|
|
1556 |
outputs=text_outputs,
|
1557 |
)
|
1558 |
all_bot_args = dict(fn=functools.partial(all_bot, model_states1=model_states),
|
1559 |
-
inputs=inputs_list + [my_db_state] + text_outputs,
|
1560 |
outputs=text_outputs + [chat_exception_text],
|
1561 |
)
|
1562 |
all_retry_bot_args = dict(fn=functools.partial(all_bot, model_states1=model_states, retry=True),
|
1563 |
-
inputs=inputs_list + [my_db_state] + text_outputs,
|
1564 |
outputs=text_outputs + [chat_exception_text],
|
1565 |
)
|
1566 |
all_retry_user_args = dict(fn=functools.partial(all_user, retry=True,
|
@@ -1722,6 +1957,11 @@ def go_gradio(**kwargs):
|
|
1722 |
def get_short_chat(x, short_chats, short_len=20, words=4):
|
1723 |
if x and len(x[0]) == 2 and x[0][0] is not None:
|
1724 |
short_chat = ' '.join(x[0][0][:short_len].split(' ')[:words]).strip()
|
|
|
|
|
|
|
|
|
|
|
1725 |
short_chat = dedup(short_chat, short_chats)
|
1726 |
else:
|
1727 |
short_chat = None
|
@@ -1789,14 +2029,12 @@ def go_gradio(**kwargs):
|
|
1789 |
already_exists = any([is_chat_same(chat_list, x) for x in old_chat_lists])
|
1790 |
if not already_exists:
|
1791 |
chat_state1[short_chat] = chat_list.copy()
|
1792 |
-
|
1793 |
-
|
1794 |
-
|
1795 |
-
|
1796 |
-
|
1797 |
-
|
1798 |
-
ret_list = [chat_list] + [chat_state1]
|
1799 |
-
return tuple(ret_list)
|
1800 |
|
1801 |
def switch_chat(chat_key, chat_state1, num_model_lock=0):
|
1802 |
chosen_chat = chat_state1[chat_key]
|
@@ -1827,7 +2065,7 @@ def go_gradio(**kwargs):
|
|
1827 |
|
1828 |
remove_chat_event = remove_chat_btn.click(remove_chat,
|
1829 |
inputs=[radio_chats, chat_state], outputs=[radio_chats, chat_state],
|
1830 |
-
queue=False)
|
1831 |
|
1832 |
def get_chats1(chat_state1):
|
1833 |
base = 'chats'
|
@@ -1858,7 +2096,7 @@ def go_gradio(**kwargs):
|
|
1858 |
new_chats = json.loads(f.read())
|
1859 |
for chat1_k, chat1_v in new_chats.items():
|
1860 |
# ignore chat1_k, regenerate and de-dup to avoid loss
|
1861 |
-
|
1862 |
except BaseException as e:
|
1863 |
t, v, tb = sys.exc_info()
|
1864 |
ex = ''.join(traceback.format_exception(t, v, tb))
|
@@ -1884,24 +2122,17 @@ def go_gradio(**kwargs):
|
|
1884 |
.then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=False) \
|
1885 |
.then(clear_scores, outputs=[score_text, score_text2, score_text_nochat])
|
1886 |
|
1887 |
-
def update_radio_chats(chat_state1):
|
1888 |
-
# reverse so newest at top
|
1889 |
-
choices = list(chat_state1.keys()).copy()
|
1890 |
-
choices.reverse()
|
1891 |
-
return gr.update(choices=choices, value=None)
|
1892 |
-
|
1893 |
clear_event = save_chat_btn.click(save_chat,
|
1894 |
inputs=[text_output, text_output2] + text_outputs + [chat_state],
|
1895 |
-
outputs=[
|
1896 |
-
api_name='save_chat' if allow_api else None)
|
1897 |
-
|
1898 |
-
|
1899 |
-
.then(clear_scores, outputs=[score_text, score_text2, score_text_nochat])
|
1900 |
|
1901 |
# NOTE: clear of instruction/iinput for nochat has to come after score,
|
1902 |
# because score for nochat consumes actual textbox, while chat consumes chat history filled by user()
|
1903 |
no_chat_args = dict(fn=fun,
|
1904 |
-
inputs=[model_state, my_db_state] + inputs_list,
|
1905 |
outputs=text_output_nochat,
|
1906 |
queue=queue,
|
1907 |
)
|
@@ -1920,7 +2151,8 @@ def go_gradio(**kwargs):
|
|
1920 |
.then(clear_torch_cache)
|
1921 |
|
1922 |
submit_event_nochat_api = submit_nochat_api.click(fun_with_dict_str,
|
1923 |
-
inputs=[model_state, my_db_state,
|
|
|
1924 |
outputs=text_output_nochat_api,
|
1925 |
queue=True, # required for generator
|
1926 |
api_name='submit_nochat_api' if allow_api else None) \
|
@@ -2170,6 +2402,8 @@ def go_gradio(**kwargs):
|
|
2170 |
print("Exception: %s" % str(e), flush=True)
|
2171 |
return json.dumps(sys_dict)
|
2172 |
|
|
|
|
|
2173 |
get_system_info_dict_func = functools.partial(get_system_info_dict, **all_kwargs)
|
2174 |
|
2175 |
system_dict_event = system_btn2.click(get_system_info_dict_func,
|
@@ -2199,12 +2433,15 @@ def go_gradio(**kwargs):
|
|
2199 |
else:
|
2200 |
tokenizer = None
|
2201 |
if tokenizer is not None:
|
2202 |
-
langchain_mode1 = '
|
|
|
2203 |
# fake user message to mimic bot()
|
2204 |
chat1 = copy.deepcopy(chat1)
|
2205 |
chat1 = chat1 + [['user_message1', None]]
|
2206 |
model_max_length1 = tokenizer.model_max_length
|
2207 |
-
context1 = history_to_context(chat1, langchain_mode1,
|
|
|
|
|
2208 |
model_max_length1,
|
2209 |
memory_restriction_level1, keep_sources_in_context1)
|
2210 |
return str(tokenizer(context1, return_tensors="pt")['input_ids'].shape[1])
|
@@ -2234,7 +2471,7 @@ def go_gradio(**kwargs):
|
|
2234 |
,
|
2235 |
queue=False, api_name='stop' if allow_api else None).then(clear_torch_cache, queue=False)
|
2236 |
|
2237 |
-
demo.load(None, None, None, _js=get_dark_js() if kwargs['
|
2238 |
|
2239 |
demo.queue(concurrency_count=kwargs['concurrency_count'], api_open=kwargs['api_open'])
|
2240 |
favicon_path = "h2o-logo.svg"
|
@@ -2249,7 +2486,8 @@ def go_gradio(**kwargs):
|
|
2249 |
# FIXME: disable for gptj, langchain or gpt4all modify print itself
|
2250 |
# FIXME: and any multi-threaded/async print will enter model output!
|
2251 |
scheduler.add_job(func=ping, trigger="interval", seconds=60)
|
2252 |
-
|
|
|
2253 |
scheduler.start()
|
2254 |
|
2255 |
# import control
|
@@ -2268,9 +2506,6 @@ def go_gradio(**kwargs):
|
|
2268 |
demo.block_thread()
|
2269 |
|
2270 |
|
2271 |
-
input_args_list = ['model_state', 'my_db_state']
|
2272 |
-
|
2273 |
-
|
2274 |
def get_inputs_list(inputs_dict, model_lower, model_id=1):
|
2275 |
"""
|
2276 |
map gradio objects in locals() to inputs for evaluate().
|
@@ -2304,8 +2539,9 @@ def get_inputs_list(inputs_dict, model_lower, model_id=1):
|
|
2304 |
return inputs_list, inputs_dict_out
|
2305 |
|
2306 |
|
2307 |
-
def get_sources(
|
2308 |
-
|
|
|
2309 |
|
2310 |
if langchain_mode in ['ChatLLM', 'LLM']:
|
2311 |
source_files_added = "NA"
|
@@ -2314,7 +2550,8 @@ def get_sources(db1, langchain_mode, dbs=None, docs_state0=None):
|
|
2314 |
source_files_added = "Not showing wiki_full, takes about 20 seconds and makes 4MB file." \
|
2315 |
" Ask jon.mckinney@h2o.ai for file if required."
|
2316 |
source_list = []
|
2317 |
-
elif langchain_mode
|
|
|
2318 |
from gpt_langchain import get_metadatas
|
2319 |
metadatas = get_metadatas(db1[0])
|
2320 |
source_list = sorted(set([x['source'] for x in metadatas]))
|
@@ -2345,14 +2582,13 @@ def set_userid(db1):
|
|
2345 |
db1[1] = str(uuid.uuid4())
|
2346 |
|
2347 |
|
2348 |
-
def update_user_db(file,
|
2349 |
-
|
2350 |
-
|
2351 |
if file is None:
|
2352 |
raise RuntimeError("Don't use change, use input")
|
2353 |
|
2354 |
try:
|
2355 |
-
return _update_user_db(file,
|
2356 |
langchain_mode=langchain_mode, dbs=dbs,
|
2357 |
**kwargs)
|
2358 |
except BaseException as e:
|
@@ -2383,25 +2619,30 @@ def get_lock_file(db1, langchain_mode):
|
|
2383 |
user_id = db1[1]
|
2384 |
base_path = 'locks'
|
2385 |
makedirs(base_path)
|
2386 |
-
lock_file = "db_%s_%s.lock" % (langchain_mode.replace(' ', '_'), user_id)
|
2387 |
return lock_file
|
2388 |
|
2389 |
|
2390 |
def _update_user_db(file,
|
2391 |
-
|
2392 |
chunk=None, chunk_size=None,
|
2393 |
-
dbs=None, db_type=None,
|
2394 |
-
|
|
|
|
|
|
|
2395 |
use_openai_embedding=None,
|
2396 |
hf_embedding_model=None,
|
2397 |
caption_loader=None,
|
2398 |
enable_captions=None,
|
2399 |
captions_model=None,
|
2400 |
enable_ocr=None,
|
|
|
2401 |
verbose=None,
|
|
|
2402 |
is_url=None, is_txt=None,
|
2403 |
-
|
2404 |
-
assert
|
2405 |
assert chunk is not None
|
2406 |
assert chunk_size is not None
|
2407 |
assert use_openai_embedding is not None
|
@@ -2410,10 +2651,9 @@ def _update_user_db(file,
|
|
2410 |
assert enable_captions is not None
|
2411 |
assert captions_model is not None
|
2412 |
assert enable_ocr is not None
|
|
|
2413 |
assert verbose is not None
|
2414 |
|
2415 |
-
set_userid(db1)
|
2416 |
-
|
2417 |
if dbs is None:
|
2418 |
dbs = {}
|
2419 |
assert isinstance(dbs, dict), "Wrong type for dbs: %s" % str(type(dbs))
|
@@ -2431,17 +2671,22 @@ def _update_user_db(file,
|
|
2431 |
if langchain_mode == LangChainMode.DISABLED.value:
|
2432 |
return None, langchain_mode, get_source_files(), ""
|
2433 |
|
2434 |
-
if langchain_mode in [LangChainMode.
|
2435 |
# then switch to MyData, so langchain_mode also becomes way to select where upload goes
|
2436 |
# but default to mydata if nothing chosen, since safest
|
2437 |
-
|
2438 |
-
|
2439 |
-
|
|
|
|
|
|
|
|
|
|
|
2440 |
# move temp files from gradio upload to stable location
|
2441 |
for fili, fil in enumerate(file):
|
2442 |
-
if isinstance(fil, str):
|
2443 |
-
|
2444 |
-
|
2445 |
if os.path.isfile(new_fil):
|
2446 |
remove(new_fil)
|
2447 |
try:
|
@@ -2461,15 +2706,22 @@ def _update_user_db(file,
|
|
2461 |
enable_captions=enable_captions,
|
2462 |
captions_model=captions_model,
|
2463 |
enable_ocr=enable_ocr,
|
|
|
2464 |
caption_loader=caption_loader,
|
2465 |
)
|
2466 |
exceptions = [x for x in sources if x.metadata.get('exception')]
|
2467 |
exceptions_strs = [x.metadata['exception'] for x in exceptions]
|
2468 |
sources = [x for x in sources if 'exception' not in x.metadata]
|
2469 |
|
2470 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
2471 |
with filelock.FileLock(lock_file):
|
2472 |
-
if langchain_mode
|
2473 |
if db1[0] is not None:
|
2474 |
# then add
|
2475 |
db, num_new_sources, new_sources_metadata = add_to_db(db1[0], sources, db_type=db_type,
|
@@ -2479,7 +2731,8 @@ def _update_user_db(file,
|
|
2479 |
# in testing expect:
|
2480 |
# assert len(db1) == 2 and db1[1] is None, "Bad MyData db: %s" % db1
|
2481 |
# for production hit, when user gets clicky:
|
2482 |
-
assert len(db1) == 2, "Bad
|
|
|
2483 |
# then create
|
2484 |
# if added has to original state and didn't change, then would be shared db for all users
|
2485 |
persist_directory = os.path.join(scratch_base_dir, 'db_dir_%s_%s' % (langchain_mode, db1[1]))
|
@@ -2501,7 +2754,7 @@ def _update_user_db(file,
|
|
2501 |
use_openai_embedding=use_openai_embedding,
|
2502 |
hf_embedding_model=hf_embedding_model)
|
2503 |
else:
|
2504 |
-
# then create
|
2505 |
db = get_db(sources, use_openai_embedding=use_openai_embedding,
|
2506 |
db_type=db_type,
|
2507 |
persist_directory=persist_directory,
|
@@ -2515,14 +2768,15 @@ def _update_user_db(file,
|
|
2515 |
return None, langchain_mode, source_files_added, '\n'.join(exceptions_strs)
|
2516 |
|
2517 |
|
2518 |
-
def get_db(
|
2519 |
-
|
|
|
2520 |
|
2521 |
with filelock.FileLock(lock_file):
|
2522 |
if langchain_mode in ['wiki_full']:
|
2523 |
# NOTE: avoid showing full wiki. Takes about 30 seconds over about 90k entries, but not useful for now
|
2524 |
db = None
|
2525 |
-
elif langchain_mode
|
2526 |
db = db1[0]
|
2527 |
elif dbs is not None and langchain_mode in dbs and dbs[langchain_mode] is not None:
|
2528 |
db = dbs[langchain_mode]
|
@@ -2531,8 +2785,8 @@ def get_db(db1, langchain_mode, dbs=None):
|
|
2531 |
return db
|
2532 |
|
2533 |
|
2534 |
-
def get_source_files_given_langchain_mode(
|
2535 |
-
db = get_db(
|
2536 |
if langchain_mode in ['ChatLLM', 'LLM'] or db is None:
|
2537 |
return "Sources: N/A"
|
2538 |
return get_source_files(db=db, exceptions=None)
|
@@ -2631,11 +2885,19 @@ def get_source_files(db=None, exceptions=None, metadatas=None):
|
|
2631 |
return source_files_added
|
2632 |
|
2633 |
|
2634 |
-
def update_and_get_source_files_given_langchain_mode(
|
2635 |
-
|
2636 |
-
|
|
|
2637 |
n_jobs=None, verbose=None):
|
2638 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2639 |
|
2640 |
from gpt_langchain import make_db
|
2641 |
db, num_new_sources, new_sources_metadata = make_db(use_openai_embedding=False,
|
@@ -2644,11 +2906,27 @@ def update_and_get_source_files_given_langchain_mode(db1, langchain_mode, dbs=No
|
|
2644 |
chunk=chunk,
|
2645 |
chunk_size=chunk_size,
|
2646 |
langchain_mode=langchain_mode,
|
2647 |
-
|
2648 |
db_type=db_type,
|
2649 |
load_db_if_exists=load_db_if_exists,
|
2650 |
db=db,
|
2651 |
n_jobs=n_jobs,
|
2652 |
verbose=verbose)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2653 |
# return only new sources with text saying such
|
2654 |
return get_source_files(db=None, exceptions=None, metadatas=new_sources_metadata)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
fix_pydantic_duplicate_validators_error()
|
52 |
|
53 |
+
from enums import DocumentSubset, no_model_str, no_lora_str, no_server_str, LangChainAction, LangChainMode, \
|
54 |
+
DocumentChoice, langchain_modes_intrinsic
|
55 |
from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, get_dark_js, spacing_xsm, radius_xsm, \
|
56 |
text_xsm
|
57 |
from prompter import prompt_type_to_model_name, prompt_types_strings, inv_prompt_type_to_model_lower, non_hf_types, \
|
58 |
get_prompt
|
59 |
+
from utils import flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
|
60 |
+
ping, get_short_name, makedirs, get_kwargs, remove, system_info, ping_gpu, get_url, get_local_ip, \
|
61 |
+
save_collection_names
|
62 |
+
from gen import get_model, languages_covered, evaluate, score_qa, inputs_kwargs_list, scratch_base_dir, \
|
63 |
+
get_max_max_new_tokens, get_minmax_top_k_docs, history_to_context, langchain_actions, langchain_agents_list, \
|
64 |
+
update_langchain
|
65 |
+
from evaluate_params import eval_func_param_names, no_default_param_names, eval_func_param_names_defaults, \
|
66 |
+
input_args_list
|
67 |
|
68 |
from apscheduler.schedulers.background import BackgroundScheduler
|
69 |
|
|
|
98 |
memory_restriction_level = kwargs['memory_restriction_level']
|
99 |
n_gpus = kwargs['n_gpus']
|
100 |
admin_pass = kwargs['admin_pass']
|
|
|
101 |
model_states = kwargs['model_states']
|
|
|
102 |
dbs = kwargs['dbs']
|
103 |
db_type = kwargs['db_type']
|
|
|
104 |
visible_langchain_actions = kwargs['visible_langchain_actions']
|
105 |
visible_langchain_agents = kwargs['visible_langchain_agents']
|
106 |
allow_upload_to_user_data = kwargs['allow_upload_to_user_data']
|
|
|
113 |
enable_captions = kwargs['enable_captions']
|
114 |
captions_model = kwargs['captions_model']
|
115 |
enable_ocr = kwargs['enable_ocr']
|
116 |
+
enable_pdf_ocr = kwargs['enable_pdf_ocr']
|
117 |
caption_loader = kwargs['caption_loader']
|
118 |
|
119 |
+
# for dynamic state per user session in gradio
|
120 |
+
model_state0 = kwargs['model_state0']
|
121 |
+
score_model_state0 = kwargs['score_model_state0']
|
122 |
+
my_db_state0 = kwargs['my_db_state0']
|
123 |
+
selection_docs_state0 = kwargs['selection_docs_state0']
|
124 |
+
# for evaluate defaults
|
125 |
+
langchain_modes0 = kwargs['langchain_modes']
|
126 |
+
visible_langchain_modes0 = kwargs['visible_langchain_modes']
|
127 |
+
langchain_mode_paths0 = kwargs['langchain_mode_paths']
|
128 |
+
|
129 |
# easy update of kwargs needed for evaluate() etc.
|
130 |
queue = True
|
131 |
allow_upload = allow_upload_to_user_data or allow_upload_to_my_data
|
|
|
145 |
" use Enter for multiple input lines)"
|
146 |
|
147 |
title = 'h2oGPT'
|
148 |
+
description = """<iframe src="https://ghbtns.com/github-btn.html?user=h2oai&repo=h2ogpt&type=star&count=true&size=small" frameborder="0" scrolling="0" width="250" height="20" title="GitHub"></iframe><small><a href="https://github.com/h2oai/h2ogpt">h2oGPT</a> <a href="https://github.com/h2oai/h2o-llmstudio">H2O LLM Studio</a><br><a href="https://huggingface.co/h2oai">🤗 Models</a>"""
|
149 |
+
description_bottom = "If this host is busy, try<br>[Multi-Model](https://gpt.h2o.ai)<br>[Falcon 40B](https://falcon.h2o.ai)<br>[Vicuna 33B](https://wizardvicuna.h2o.ai)<br>[MPT 30B-Chat](https://mpt.h2o.ai)<br>[HF Spaces1](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot)<br>[HF Spaces2](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot2)<br>"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
150 |
if is_hf:
|
151 |
description_bottom += '''<a href="https://huggingface.co/spaces/h2oai/h2ogpt-chatbot?duplicate=true"><img src="https://bit.ly/3gLdBN6" style="white-space: nowrap" alt="Duplicate Space"></a>'''
|
152 |
+
task_info_md = ''
|
|
|
|
|
|
|
|
|
|
|
|
|
153 |
css_code = get_css(kwargs)
|
154 |
|
155 |
if kwargs['gradio_offline_level'] >= 0:
|
|
|
179 |
demo = gr.Blocks(theme=theme, css=css_code, title="h2oGPT", analytics_enabled=False)
|
180 |
callback = gr.CSVLogger()
|
181 |
|
182 |
+
model_options0 = flatten_list(list(prompt_type_to_model_name.values())) + kwargs['extra_model_options']
|
183 |
+
if kwargs['base_model'].strip() not in model_options0:
|
184 |
+
model_options0 = [kwargs['base_model'].strip()] + model_options0
|
185 |
lora_options = kwargs['extra_lora_options']
|
186 |
if kwargs['lora_weights'].strip() not in lora_options:
|
187 |
lora_options = [kwargs['lora_weights'].strip()] + lora_options
|
|
|
196 |
|
197 |
# always add in no lora case
|
198 |
# add fake space so doesn't go away in gradio dropdown
|
199 |
+
model_options0 = [no_model_str] + model_options0
|
200 |
lora_options = [no_lora_str] + lora_options
|
201 |
server_options = [no_server_str] + server_options
|
202 |
# always add in no model case so can free memory
|
|
|
250 |
# else gets input_list at time of submit that is old, and shows up as truncated in chatbot
|
251 |
return x
|
252 |
|
253 |
+
def allow_empty_instruction(langchain_mode1, document_subset1, langchain_action1):
    """Return True when a request may be submitted without an instruction.

    An empty instruction is acceptable when the selected action is anything
    other than QUERY, or when the document subset is TopKSources.  It is
    never acceptable in plain LLM mode.

    :param langchain_mode1: currently selected collection/mode value
    :param document_subset1: name of the selected DocumentSubset member
    :param langchain_action1: value of the selected LangChainAction member
    :return: bool
    """
    if langchain_mode1 == LangChainMode.LLM.value:
        return False
    # Compare for equality: the right-hand sides are single strings, so
    # ``in`` / ``not in`` would be a substring test (matching '' and raising
    # TypeError for None) rather than a membership test.
    # NOTE(review): assumes LangChainAction.QUERY.value is a plain string, as its use as a UI value suggests.
    return (langchain_action1 != LangChainAction.QUERY.value
            or document_subset1 == DocumentSubset.TopKSources.name)
|
260 |
+
|
261 |
with demo:
|
262 |
# avoid actual model/tokenizer here or anything that would be bad to deepcopy
|
263 |
# https://github.com/gradio-app/gradio/issues/3558
|
|
|
271 |
prompt_dict=kwargs['prompt_dict'],
|
272 |
)
|
273 |
)
|
274 |
+
|
275 |
+
def update_langchain_mode_paths(db1s, selection_docs_state1):
    """Synchronise the collection->path mapping held in the selection state.

    When uploads to personal (scratch) collections are allowed, every key of
    ``db1s`` is registered with a path of None.  Afterwards any mapping entry
    whose collection is not among the visible modes is dropped.  The state
    dict is modified in place and also returned.
    """
    mode_paths = selection_docs_state1['langchain_mode_paths']
    if allow_upload_to_my_data:
        for scratch_mode in db1s:
            mode_paths[scratch_mode] = None
    # iterate over a snapshot of the keys because entries are deleted while looping
    for mode in list(mode_paths):
        if mode not in selection_docs_state1['visible_langchain_modes']:
            del mode_paths[mode]
    return selection_docs_state1
|
283 |
+
|
284 |
+
# Setup some gradio states for per-user dynamic state
|
285 |
model_state2 = gr.State(kwargs['model_state_none'].copy())
|
286 |
+
model_options_state = gr.State([model_options0])
|
287 |
lora_options_state = gr.State([lora_options])
|
288 |
server_options_state = gr.State([server_options])
|
289 |
+
my_db_state = gr.State(my_db_state0)
|
290 |
chat_state = gr.State({})
|
291 |
+
docs_state00 = kwargs['document_choice'] + [DocumentChoice.ALL.value]
|
292 |
docs_state0 = []
|
293 |
[docs_state0.append(x) for x in docs_state00 if x not in docs_state0]
|
294 |
docs_state = gr.State(docs_state0)
|
295 |
viewable_docs_state0 = []
|
296 |
viewable_docs_state = gr.State(viewable_docs_state0)
|
297 |
+
selection_docs_state0 = update_langchain_mode_paths(my_db_state0, selection_docs_state0)
|
298 |
+
selection_docs_state = gr.State(selection_docs_state0)
|
299 |
+
|
300 |
gr.Markdown(f"""
|
301 |
{get_h2o_title(title, description) if kwargs['h2ocolors'] else get_simple_title(title, description)}
|
302 |
""")
|
|
|
310 |
'model_lock'] else "Response Scores: %s" % nas
|
311 |
|
312 |
if kwargs['langchain_mode'] != LangChainMode.DISABLED.value:
|
313 |
+
extra_prompt_form = ". For summarization, no query required, just click submit"
|
314 |
else:
|
315 |
extra_prompt_form = ""
|
316 |
if kwargs['input_lines'] > 1:
|
|
|
318 |
else:
|
319 |
instruction_label = "Enter to Submit, Shift-Enter for more lines%s" % extra_prompt_form
|
320 |
|
321 |
+
def get_langchain_choices(selection_docs_state1):
    """Build the list of collection names offered in the Collections radio.

    Starts from the visible modes, always permits 'LLM', and permits
    'MyData' / 'UserData' when the matching upload option is enabled.
    The result keeps the ordering of ``langchain_modes`` and leaves out
    modes that must never be shown.
    """
    all_modes = selection_docs_state1['langchain_modes']
    visible_modes = selection_docs_state1['visible_langchain_modes']

    # 'wiki' is hidden on HF since it is only usually useful for internal testing at the moment
    hidden = ['Disabled', 'wiki'] if is_hf else ['Disabled']

    # work on a copy so the state's own list is not modified
    permitted = visible_modes.copy()
    permitted.append('LLM')
    if allow_upload_to_my_data and 'MyData' not in permitted:
        permitted.append('MyData')
    if allow_upload_to_user_data and 'UserData' not in permitted:
        permitted.append('UserData')

    return [mode for mode in all_modes if mode in permitted and mode not in hidden]
|
339 |
+
|
340 |
+
def get_df_langchain_mode_paths(selection_docs_state1):
    """Return the collection->path mapping as a two-column DataFrame.

    :param selection_docs_state1: state dict holding 'langchain_mode_paths'
    :return: DataFrame with columns 'Collection' and 'Path' (one row per
             mapping entry, in mapping order), or a completely empty
             DataFrame when the mapping is empty
    """
    langchain_mode_paths = selection_docs_state1['langchain_mode_paths']
    if not langchain_mode_paths:
        return pd.DataFrame(None)
    # Build directly from (name, path) rows with explicit column names instead
    # of handing a dict_items view to DataFrame.from_dict, which is documented
    # for dict input only.
    return pd.DataFrame(list(langchain_mode_paths.items()), columns=['Collection', 'Path'])
|
348 |
+
|
349 |
normal_block = gr.Row(visible=not base_wanted, equal_height=False)
|
350 |
with normal_block:
|
351 |
side_bar = gr.Column(elem_id="col_container", scale=1, min_width=100)
|
|
|
366 |
scale=1,
|
367 |
min_width=0,
|
368 |
elem_id="warning", elem_classes="feedback")
|
369 |
+
fileup_output_text = gr.Textbox(visible=False)
|
370 |
url_visible = kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_url_upload
|
371 |
url_label = 'URL/ArXiv' if have_arxiv else 'URL'
|
372 |
url_text = gr.Textbox(label=url_label,
|
|
|
380 |
visible=text_visible)
|
381 |
github_textbox = gr.Textbox(label="Github URL", visible=False) # FIXME WIP
|
382 |
database_visible = kwargs['langchain_mode'] != 'Disabled'
|
383 |
+
with gr.Accordion("Resources", open=False, visible=database_visible):
|
384 |
+
langchain_choices0 = get_langchain_choices(selection_docs_state0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
385 |
langchain_mode = gr.Radio(
|
386 |
+
langchain_choices0,
|
387 |
value=kwargs['langchain_mode'],
|
388 |
label="Collections",
|
389 |
show_label=True,
|
390 |
visible=kwargs['langchain_mode'] != 'Disabled',
|
391 |
min_width=100)
|
392 |
+
add_chat_history_to_context = gr.Checkbox(label="Chat History",
|
393 |
+
value=kwargs['add_chat_history_to_context'])
|
394 |
+
document_subset = gr.Radio([x.name for x in DocumentSubset],
|
395 |
label="Subset",
|
396 |
+
value=DocumentSubset.Relevant.name,
|
397 |
interactive=True,
|
398 |
)
|
399 |
allowed_actions = [x for x in langchain_actions if x in visible_langchain_actions]
|
|
|
457 |
mw1 = 50
|
458 |
mw2 = 50
|
459 |
with gr.Column(min_width=mw1):
|
460 |
+
submit = gr.Button(value='Submit', variant='primary', size='sm',
|
461 |
min_width=mw1)
|
462 |
+
stop_btn = gr.Button(value="Stop", variant='secondary', size='sm',
|
463 |
min_width=mw1)
|
464 |
save_chat_btn = gr.Button("Save", size='sm', min_width=mw1)
|
465 |
with gr.Column(min_width=mw2):
|
|
|
480 |
with gr.TabItem("Document Selection"):
|
481 |
document_choice = gr.Dropdown(docs_state0,
|
482 |
label="Select Subset of Document(s) %s" % file_types_str,
|
483 |
+
value=[DocumentChoice.ALL.value],
|
484 |
interactive=True,
|
485 |
multiselect=True,
|
486 |
visible=kwargs['langchain_mode'] != 'Disabled',
|
487 |
)
|
488 |
sources_visible = kwargs['langchain_mode'] != 'Disabled' and enable_sources_list
|
489 |
with gr.Row():
|
490 |
+
with gr.Column(scale=1):
|
491 |
+
get_sources_btn = gr.Button(value="Update UI with Document(s) from DB", scale=0, size='sm',
|
492 |
+
visible=sources_visible)
|
493 |
+
show_sources_btn = gr.Button(value="Show Sources from DB", scale=0, size='sm',
|
494 |
+
visible=sources_visible)
|
495 |
+
refresh_sources_btn = gr.Button(value="Update DB with new/changed files on disk", scale=0,
|
496 |
+
size='sm',
|
497 |
+
visible=sources_visible and allow_upload_to_user_data)
|
498 |
+
with gr.Column(scale=4):
|
499 |
+
pass
|
500 |
+
with gr.Row():
|
501 |
+
with gr.Column(scale=1):
|
502 |
+
add_placeholder = "e.g. UserData2, user_path2 (optional)" \
|
503 |
+
if not is_public else "e.g. MyData2"
|
504 |
+
remove_placeholder = "e.g. UserData2" if not is_public else "e.g. MyData2"
|
505 |
+
new_langchain_mode_text = gr.Textbox(value="", visible=allow_upload_to_user_data or
|
506 |
+
allow_upload_to_my_data,
|
507 |
+
label='Add Collection',
|
508 |
+
placeholder=add_placeholder,
|
509 |
+
interactive=True)
|
510 |
+
remove_langchain_mode_text = gr.Textbox(value="", visible=allow_upload_to_user_data or
|
511 |
+
allow_upload_to_my_data,
|
512 |
+
label='Remove Collection',
|
513 |
+
placeholder=remove_placeholder,
|
514 |
+
interactive=True)
|
515 |
+
load_langchain = gr.Button(value="Load LangChain State", scale=0, size='sm',
|
516 |
+
visible=allow_upload_to_user_data)
|
517 |
+
with gr.Column(scale=1):
|
518 |
+
df0 = get_df_langchain_mode_paths(selection_docs_state0)
|
519 |
+
langchain_mode_path_text = gr.Dataframe(value=df0,
|
520 |
+
visible=allow_upload_to_user_data or
|
521 |
+
allow_upload_to_my_data,
|
522 |
+
label='LangChain Mode-Path',
|
523 |
+
show_label=False,
|
524 |
+
interactive=False)
|
525 |
+
with gr.Column(scale=4):
|
526 |
+
pass
|
527 |
|
528 |
sources_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list,
|
529 |
equal_height=False)
|
|
|
793 |
side_bar_btn = gr.Button("Toggle SideBar", variant="secondary", size="sm")
|
794 |
submit_buttons_btn = gr.Button("Toggle Submit Buttons", variant="secondary", size="sm")
|
795 |
col_tabs_scale = gr.Slider(minimum=1, maximum=20, value=10, step=1, label='Window Size')
|
796 |
+
text_outputs_height = gr.Slider(minimum=100, maximum=2000, value=kwargs['height'] or 400,
|
797 |
+
step=50, label='Chat Height')
|
798 |
dark_mode_btn = gr.Button("Dark Mode", variant="secondary", size="sm")
|
799 |
with gr.Column(scale=4):
|
800 |
pass
|
801 |
+
system_visible0 = not is_public and not admin_pass
|
802 |
admin_row = gr.Row()
|
803 |
with admin_row:
|
804 |
with gr.Column(scale=1):
|
805 |
+
admin_pass_textbox = gr.Textbox(label="Admin Password", type='password',
|
806 |
+
visible=not system_visible0)
|
807 |
with gr.Column(scale=4):
|
808 |
pass
|
809 |
+
system_row = gr.Row(visible=system_visible0)
|
810 |
with system_row:
|
811 |
with gr.Column():
|
812 |
with gr.Row():
|
|
|
870 |
else:
|
871 |
return tuple([gr.update(interactive=True)] * len(args))
|
872 |
|
873 |
+
# Add to UserData or custom user db
|
874 |
update_db_func = functools.partial(update_user_db,
|
875 |
dbs=dbs,
|
876 |
db_type=db_type,
|
877 |
use_openai_embedding=use_openai_embedding,
|
878 |
hf_embedding_model=hf_embedding_model,
|
|
|
879 |
captions_model=captions_model,
|
880 |
+
enable_captions=enable_captions,
|
881 |
caption_loader=caption_loader,
|
882 |
+
enable_ocr=enable_ocr,
|
883 |
+
enable_pdf_ocr=enable_pdf_ocr,
|
884 |
verbose=kwargs['verbose'],
|
|
|
885 |
n_jobs=kwargs['n_jobs'],
|
886 |
)
|
887 |
add_file_outputs = [fileup_output, langchain_mode]
|
888 |
add_file_kwargs = dict(fn=update_db_func,
|
889 |
+
inputs=[fileup_output, my_db_state, selection_docs_state, chunk, chunk_size,
|
890 |
+
langchain_mode],
|
891 |
outputs=add_file_outputs + [sources_text, doc_exception_text],
|
892 |
queue=queue,
|
893 |
api_name='add_file' if allow_api and allow_upload_to_user_data else None)
|
|
|
899 |
eventdb1b = eventdb1.then(make_interactive, inputs=add_file_outputs, outputs=add_file_outputs,
|
900 |
show_progress='minimal')
|
901 |
|
902 |
+
# deal with challenge to have fileup_output itself as input
|
903 |
+
add_file_kwargs2 = dict(fn=update_db_func,
|
904 |
+
inputs=[fileup_output_text, my_db_state, selection_docs_state, chunk, chunk_size,
|
905 |
+
langchain_mode],
|
906 |
+
outputs=add_file_outputs + [sources_text, doc_exception_text],
|
907 |
+
queue=queue,
|
908 |
+
api_name='add_file_api' if allow_api and allow_upload_to_user_data else None)
|
909 |
+
eventdb1_api = fileup_output_text.submit(**add_file_kwargs2, show_progress='full')
|
910 |
+
|
911 |
# note for update_user_db_func output is ignored for db
|
912 |
|
913 |
def clear_textbox():
|
|
|
917 |
|
918 |
add_url_outputs = [url_text, langchain_mode]
|
919 |
add_url_kwargs = dict(fn=update_user_db_url_func,
|
920 |
+
inputs=[url_text, my_db_state, selection_docs_state, chunk, chunk_size,
|
921 |
+
langchain_mode],
|
922 |
outputs=add_url_outputs + [sources_text, doc_exception_text],
|
923 |
queue=queue,
|
924 |
api_name='add_url' if allow_api and allow_upload_to_user_data else None)
|
|
|
935 |
update_user_db_txt_func = functools.partial(update_db_func, is_txt=True)
|
936 |
add_text_outputs = [user_text_text, langchain_mode]
|
937 |
add_text_kwargs = dict(fn=update_user_db_txt_func,
|
938 |
+
inputs=[user_text_text, my_db_state, selection_docs_state, chunk, chunk_size,
|
939 |
+
langchain_mode],
|
940 |
outputs=add_text_outputs + [sources_text, doc_exception_text],
|
941 |
queue=queue,
|
942 |
api_name='add_text' if allow_api and allow_upload_to_user_data else None
|
|
|
948 |
eventdb3 = eventdb3b.then(**add_text_kwargs, show_progress='full')
|
949 |
eventdb3c = eventdb3.then(make_interactive, inputs=add_text_outputs, outputs=add_text_outputs,
|
950 |
show_progress='minimal')
|
951 |
+
db_events = [eventdb1a, eventdb1, eventdb1b, eventdb1_api,
|
952 |
eventdb2a, eventdb2, eventdb2b, eventdb2c,
|
953 |
eventdb3a, eventdb3b, eventdb3, eventdb3c]
|
954 |
|
|
|
956 |
|
957 |
# if change collection source, must clear doc selections from it to avoid inconsistency
|
958 |
def clear_doc_choice():
    """Reset the document-selection dropdown to its initial choices.

    The selection is returned as a list because the dropdown is created with
    multiselect=True and a list-valued default.
    """
    return gr.Dropdown.update(choices=docs_state0, value=[DocumentChoice.ALL.value])
|
960 |
|
961 |
langchain_mode.change(clear_doc_choice, inputs=None, outputs=document_choice, queue=False)
|
962 |
|
963 |
def resize_col_tabs(x):
    """Return a component update that sets the layout ``scale`` to ``x``."""
    new_scale = x
    return gr.Dropdown.update(scale=new_scale)
|
965 |
|
966 |
+
col_tabs_scale.change(fn=resize_col_tabs, inputs=col_tabs_scale, outputs=col_tabs, queue=False)
|
967 |
|
968 |
def resize_chatbots(x, num_model_lock=0):
|
969 |
if num_model_lock == 0:
|
|
|
974 |
|
975 |
resize_chatbots_func = functools.partial(resize_chatbots, num_model_lock=len(text_outputs))
|
976 |
text_outputs_height.change(fn=resize_chatbots_func, inputs=text_outputs_height,
|
977 |
+
outputs=[text_output, text_output2] + text_outputs, queue=False)
|
978 |
|
979 |
def update_dropdown(x):
    """Return a dropdown update that installs ``x`` as the choices and
    selects the first entry of the initial document list."""
    default_selection = [docs_state0[0]]
    return gr.Dropdown.update(choices=x, value=default_selection)
|
|
|
1065 |
if file.startswith('http') or file.startswith('https'):
|
1066 |
# if file is online, then might as well use google(?)
|
1067 |
document1 = file
|
1068 |
+
return gr.update(visible=True,
|
1069 |
+
value=f"""<iframe width="1000" height="800" src="https://docs.google.com/viewerng/viewer?url={document1}&embedded=true" frameborder="0" height="100%" width="100%">
|
1070 |
</iframe>
|
1071 |
"""), dummy1, dummy1, dummy1
|
1072 |
else:
|
|
|
1089 |
|
1090 |
refresh_sources1 = functools.partial(update_and_get_source_files_given_langchain_mode,
|
1091 |
**get_kwargs(update_and_get_source_files_given_langchain_mode,
|
1092 |
+
exclude_names=['db1s', 'langchain_mode', 'chunk',
|
1093 |
+
'chunk_size'],
|
1094 |
**all_kwargs))
|
1095 |
+
eventdb9 = refresh_sources_btn.click(fn=refresh_sources1,
|
1096 |
+
inputs=[my_db_state, langchain_mode, chunk, chunk_size],
|
1097 |
outputs=sources_text,
|
1098 |
api_name='refresh_sources' if allow_api else None)
|
1099 |
|
|
|
1103 |
def close_admin(x):
    """Return an update that hides the target once ``x`` equals the admin
    password and keeps it visible otherwise."""
    still_locked = x != admin_pass
    return gr.update(visible=still_locked)
|
1105 |
|
1106 |
+
admin_pass_textbox.submit(check_admin_pass, inputs=admin_pass_textbox, outputs=system_row, queue=False) \
|
1107 |
.then(close_admin, inputs=admin_pass_textbox, outputs=admin_row, queue=False)
|
1108 |
|
1109 |
+
def add_langchain_mode(db1s, selection_docs_state1, langchain_mode1, y):
    """Register a new collection from the text typed into 'Add Collection'.

    ``y`` is parsed as ``"name"`` or ``"name, user_path"`` (all spaces are
    removed).  A name must be at least 3 characters and alphanumeric, must
    not be one of the intrinsic mode names, and the matching upload option
    must be enabled (user-data upload when a path is given, personal/scratch
    upload when it is not).  On success the name is added to the mode lists
    and path mapping of ``selection_docs_state1`` (modified in place), the
    path directory is created if one was given, and the collection names are
    saved.  On failure nothing is registered and the previous selection
    ``langchain_mode1`` is kept.

    :param db1s: mapping of collection name -> per-user db state; a
                 ``[None, None]`` entry is added for a new path-less collection
    :param selection_docs_state1: dict with 'langchain_modes',
                 'langchain_mode_paths' and 'visible_langchain_modes'
    :param langchain_mode1: currently selected collection
    :param y: user-entered text
    :return: (db1s, selection_docs_state1, radio update with new choices and
             selection, message text (empty on success), DataFrame of mode paths)
    """
    for k in db1s:
        set_userid(db1s[k])
    langchain_modes = selection_docs_state1['langchain_modes']
    langchain_mode_paths = selection_docs_state1['langchain_mode_paths']
    visible_langchain_modes = selection_docs_state1['visible_langchain_modes']

    # str.split always returns at least one element, so a (possibly empty)
    # name is always present and is validated below.
    y2 = y.strip().replace(' ', '').split(',')
    langchain_mode2 = y2[0]
    user_path = None
    valid = False

    if not (len(langchain_mode2) >= 3 and langchain_mode2.isalnum()):
        # real restriction is:
        # ValueError: Expected collection name that (1) contains 3-63 characters, (2) starts and ends with an alphanumeric character, (3) otherwise contains only alphanumeric characters, underscores or hyphens (-), (4) contains no two consecutive periods (..) and (5) is not a valid IPv4 address, got me
        # but just make simpler
        textbox = "Invalid, collection must be >=3 characters and alphanumeric"
    else:
        user_path = y2[1] if len(y2) > 1 else None  # assume scratch if don't have user_path
        if user_path in ['', "''"]:
            # for scratch spaces
            user_path = None
        if langchain_mode2 in langchain_modes_intrinsic:
            user_path = None
            textbox = "Invalid access to use internal name: %s" % langchain_mode2
        elif user_path and allow_upload_to_user_data or not user_path and allow_upload_to_my_data:
            valid = True
            langchain_mode_paths.update({langchain_mode2: user_path})
            if langchain_mode2 not in visible_langchain_modes:
                visible_langchain_modes.append(langchain_mode2)
            if langchain_mode2 not in langchain_modes:
                langchain_modes.append(langchain_mode2)
            textbox = ''
            if user_path:
                # NOTE(review): user_path comes straight from the UI text box;
                # confirm that creating arbitrary directories is acceptable here.
                makedirs(user_path, exist_ok=True)
        else:
            textbox = "Invalid access. user allowed: %s " \
                      "scratch allowed: %s" % (allow_upload_to_user_data, allow_upload_to_my_data)

    if not valid:
        # keep whatever collection was selected before the failed request
        langchain_mode2 = langchain_mode1

    selection_docs_state1 = update_langchain_mode_paths(db1s, selection_docs_state1)
    df_langchain_mode_paths1 = get_df_langchain_mode_paths(selection_docs_state1)
    choices = get_langchain_choices(selection_docs_state1)

    if valid and not user_path:
        # needs to have key for it to make it known different from userdata case in _update_user_db()
        # NOTE(review): this replaces any existing entry of the same name - confirm that re-adding
        # an existing scratch collection is meant to reset its state.
        db1s[langchain_mode2] = [None, None]
    if valid:
        save_collection_names(langchain_modes, visible_langchain_modes, langchain_mode_paths, LangChainMode,
                              db1s)

    return db1s, selection_docs_state1, gr.update(choices=choices,
                                                  value=langchain_mode2), textbox, df_langchain_mode_paths1
|
1169 |
+
|
1170 |
+
def remove_langchain_mode(db1s, selection_docs_state1, langchain_mode1, langchain_mode2, dbsu=None):
    """
    Remove the collection named langchain_mode2 from the session's selection state.

    :param db1s: dict of this session's databases keyed by collection name
    :param selection_docs_state1: dict holding 'langchain_modes', 'langchain_mode_paths' and
           'visible_langchain_modes'; those containers are modified in place
    :param langchain_mode1: currently selected collection (not used by the removal itself)
    :param langchain_mode2: name of the collection to remove
    :param dbsu: shared databases, expected to be bound by the caller; must not be None
    :return: tuple of (db1s, selection_docs_state1, gr.update for the collection chooser,
             status text, dataframe of collection paths)
    """
    for scratch_key in db1s:
        set_userid(db1s[scratch_key])
    assert dbsu is not None
    all_modes = selection_docs_state1['langchain_modes']
    mode_paths = selection_docs_state1['langchain_mode_paths']
    visible_modes = selection_docs_state1['visible_langchain_modes']

    # removal is refused for intrinsic collections and for collections whose upload flag is off
    denied = (
        (langchain_mode2 in db1s and not allow_upload_to_my_data)
        or (dbsu is not None and langchain_mode2 in dbsu and not allow_upload_to_user_data)
        or langchain_mode2 in langchain_modes_intrinsic
    )

    if denied:
        # NOTE: Doesn't fail if remove MyData, but didn't debug odd behavior seen with upload after gone
        textbox = "Invalid access, cannot remove %s" % langchain_mode2
        df_langchain_mode_paths1 = get_df_langchain_mode_paths(selection_docs_state1)
    else:
        # change global variables
        was_visible = langchain_mode2 in visible_modes
        if was_visible:
            visible_modes.remove(langchain_mode2)
        textbox = "" if was_visible else "%s was not visible" % langchain_mode2
        if langchain_mode2 in all_modes:
            all_modes.remove(langchain_mode2)
        # pop with a default is a no-op when the collection has no path entry
        mode_paths.pop(langchain_mode2, None)
        # remove db entirely, so not in list, else need to manage visible list in update_langchain_mode_paths()
        # FIXME: Remove location?
        # don't remove last MyData, used as user hash
        if langchain_mode2 in db1s and langchain_mode2 != LangChainMode.MY_DATA.value:
            db1s.pop(langchain_mode2)
        # only show
        selection_docs_state1 = update_langchain_mode_paths(db1s, selection_docs_state1)
        df_langchain_mode_paths1 = get_df_langchain_mode_paths(selection_docs_state1)

        save_collection_names(all_modes, visible_modes, mode_paths, LangChainMode,
                              db1s)

    chooser_update = gr.update(choices=get_langchain_choices(selection_docs_state1),
                               value=langchain_mode2)
    return db1s, selection_docs_state1, chooser_update, textbox, df_langchain_mode_paths1
|
1211 |
+
|
1212 |
+
new_langchain_mode_text.submit(fn=add_langchain_mode,
|
1213 |
+
inputs=[my_db_state, selection_docs_state, langchain_mode,
|
1214 |
+
new_langchain_mode_text],
|
1215 |
+
outputs=[my_db_state, selection_docs_state, langchain_mode,
|
1216 |
+
new_langchain_mode_text,
|
1217 |
+
langchain_mode_path_text],
|
1218 |
+
api_name='new_langchain_mode_text' if allow_api and allow_upload_to_user_data else None)
|
1219 |
+
remove_langchain_mode_func = functools.partial(remove_langchain_mode, dbsu=dbs)
|
1220 |
+
remove_langchain_mode_text.submit(fn=remove_langchain_mode_func,
|
1221 |
+
inputs=[my_db_state, selection_docs_state, langchain_mode,
|
1222 |
+
remove_langchain_mode_text],
|
1223 |
+
outputs=[my_db_state, selection_docs_state, langchain_mode,
|
1224 |
+
remove_langchain_mode_text,
|
1225 |
+
langchain_mode_path_text],
|
1226 |
+
api_name='remove_langchain_mode_text' if allow_api and allow_upload_to_user_data else None)
|
1227 |
+
|
1228 |
+
def update_langchain_gr(db1s, selection_docs_state1, langchain_mode1):
    """
    Refresh the known collections and rebuild the collection chooser.

    :param db1s: dict of this session's databases keyed by collection name; each value is a
           2-item list whose second item is used as the user hash
    :param selection_docs_state1: dict holding 'langchain_modes', 'langchain_mode_paths' and
           'visible_langchain_modes'; those containers are passed to update_langchain()
    :param langchain_mode1: currently selected collection, kept as the chooser's value
    :return: tuple of (selection_docs_state1, gr.update for the collection chooser,
             dataframe of collection paths)
    """
    for k in db1s:
        set_userid(db1s[k])
    langchain_modes = selection_docs_state1['langchain_modes']
    langchain_mode_paths = selection_docs_state1['langchain_mode_paths']
    visible_langchain_modes = selection_docs_state1['visible_langchain_modes']
    # in-place

    # update user collaborative collections
    update_langchain(langchain_modes, visible_langchain_modes, langchain_mode_paths, '')
    # update scratch single-user collections
    # the fallback must be indexable like a real entry: the previous '' default raised
    # IndexError on [1] whenever MyData was missing from db1s
    user_hash = db1s.get(LangChainMode.MY_DATA.value, [None, ''])[1]
    update_langchain(langchain_modes, visible_langchain_modes, langchain_mode_paths, user_hash)

    selection_docs_state1 = update_langchain_mode_paths(db1s, selection_docs_state1)
    df_langchain_mode_paths1 = get_df_langchain_mode_paths(selection_docs_state1)
    return selection_docs_state1, \
        gr.update(choices=get_langchain_choices(selection_docs_state1),
                  value=langchain_mode1), df_langchain_mode_paths1
|
1247 |
+
|
1248 |
+
load_langchain.click(fn=update_langchain_gr,
|
1249 |
+
inputs=[my_db_state, selection_docs_state, langchain_mode],
|
1250 |
+
outputs=[selection_docs_state, langchain_mode, langchain_mode_path_text],
|
1251 |
+
api_name='load_langchain' if allow_api and allow_upload_to_user_data else None)
|
1252 |
+
|
1253 |
inputs_list, inputs_dict = get_inputs_list(all_kwargs, kwargs['model_lower'], model_id=1)
|
1254 |
inputs_list2, inputs_dict2 = get_inputs_list(all_kwargs, kwargs['model_lower'], model_id=2)
|
1255 |
from functools import partial
|
|
|
1261 |
def evaluate_nochat(*args1, default_kwargs1=None, str_api=False, **kwargs1):
|
1262 |
args_list = list(args1)
|
1263 |
if str_api:
|
1264 |
+
user_kwargs = args_list[len(input_args_list)]
|
1265 |
assert isinstance(user_kwargs, str)
|
1266 |
user_kwargs = ast.literal_eval(user_kwargs)
|
1267 |
else:
|
1268 |
+
user_kwargs = {k: v for k, v in zip(eval_func_param_names, args_list[len(input_args_list):])}
|
1269 |
# only used for submit_nochat_api
|
1270 |
user_kwargs['chat'] = False
|
1271 |
if 'stream_output' not in user_kwargs:
|
|
|
1284 |
# correct ordering. Note some things may not be in default_kwargs, so can't be default of user_kwargs.get()
|
1285 |
model_state1 = args_list[0]
|
1286 |
my_db_state1 = args_list[1]
|
1287 |
+
selection_docs_state1 = args_list[2]
|
1288 |
args_list = [user_kwargs[k] if k in user_kwargs and user_kwargs[k] is not None else default_kwargs1[k] for k
|
1289 |
in eval_func_param_names]
|
1290 |
assert len(args_list) == len(eval_func_param_names)
|
1291 |
+
args_list = [model_state1, my_db_state1, selection_docs_state1] + args_list
|
1292 |
|
1293 |
try:
|
1294 |
for res_dict in evaluate(*tuple(args_list), **kwargs1):
|
|
|
1492 |
history[-1][1] = None
|
1493 |
return history
|
1494 |
if user_message1 in ['', None, '\n']:
|
1495 |
+
if not allow_empty_instruction(langchain_mode1, document_subset1, langchain_action1):
|
|
|
|
|
|
|
1496 |
# reject non-retry submit/enter
|
1497 |
return history
|
1498 |
user_message1 = fix_text_for_gradio(user_message1)
|
|
|
1539 |
API only called for which_model=0, default for inputs_list, but rest should ignore inputs_list
|
1540 |
:return: last element is True if should run bot, False if should just yield history
|
1541 |
"""
|
1542 |
+
isize = len(input_args_list) + 1 # states + chat history
|
1543 |
# don't deepcopy, can contain model itself
|
1544 |
args_list = list(args).copy()
|
1545 |
+
model_state1 = args_list[-isize]
|
1546 |
+
my_db_state1 = args_list[-isize + 1]
|
1547 |
+
selection_docs_state1 = args_list[-isize + 2]
|
1548 |
history = args_list[-1]
|
1549 |
prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
|
1550 |
prompt_dict1 = args_list[eval_func_param_names.index('prompt_dict')]
|
|
|
1552 |
if model_state1['model'] is None or model_state1['model'] == no_model_str:
|
1553 |
return history, None, None, None
|
1554 |
|
1555 |
+
args_list = args_list[:-isize] # only keep rest needed for evaluate()
|
1556 |
langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
|
1557 |
+
add_chat_history_to_context1 = args_list[eval_func_param_names.index('add_chat_history_to_context')]
|
1558 |
langchain_action1 = args_list[eval_func_param_names.index('langchain_action')]
|
1559 |
langchain_agents1 = args_list[eval_func_param_names.index('langchain_agents')]
|
1560 |
document_subset1 = args_list[eval_func_param_names.index('document_subset')]
|
|
|
1569 |
instruction1 = history[-1][0]
|
1570 |
history[-1][1] = None
|
1571 |
elif not instruction1:
|
1572 |
+
if not allow_empty_instruction(langchain_mode1, document_subset1, langchain_action1):
|
|
|
|
|
|
|
1573 |
# if not retrying, then reject empty query
|
1574 |
return history, None, None, None
|
1575 |
elif len(history) > 0 and history[-1][1] not in [None, '']:
|
|
|
1586 |
|
1587 |
chat1 = args_list[eval_func_param_names.index('chat')]
|
1588 |
model_max_length1 = get_model_max_length(model_state1)
|
1589 |
+
context1 = history_to_context(history, langchain_mode1,
|
1590 |
+
add_chat_history_to_context1,
|
1591 |
+
prompt_type1, prompt_dict1, chat1,
|
1592 |
model_max_length1, memory_restriction_level,
|
1593 |
kwargs['keep_sources_in_context'])
|
1594 |
args_list[0] = instruction1 # override original instruction with history from user
|
|
|
1597 |
fun1 = partial(evaluate,
|
1598 |
model_state1,
|
1599 |
my_db_state1,
|
1600 |
+
selection_docs_state1,
|
1601 |
*tuple(args_list),
|
1602 |
**kwargs_evaluate)
|
1603 |
|
|
|
1643 |
clear_torch_cache()
|
1644 |
return
|
1645 |
|
1646 |
+
def clear_embeddings(langchain_mode1, db1s):
    """
    Release the embedding held by the databases of the given collection.

    Only acts when db_type is 'chroma' and langchain_mode1 names a real collection.

    :param langchain_mode1: collection name used for the request
    :param db1s: dict of this session's databases keyed by collection name, or None
    :return: None
    """
    # clear any use of embedding that sits on GPU, else keeps accumulating GPU usage even if clear torch cache
    if db_type == 'chroma' and langchain_mode1 not in ['LLM', 'Disabled', None, '']:
        from gpt_langchain import clear_embedding
        # look up by the variable: the previous literal key 'langchain_mode1' never matched,
        # so the shared database's embedding was never cleared
        db = dbs.get(langchain_mode1)
        if db is not None and not isinstance(db, str):
            clear_embedding(db)
        if db1s is not None and langchain_mode1 in db1s:
            db1 = db1s[langchain_mode1]
            if len(db1) == 2:
                clear_embedding(db1[0])
|
1657 |
|
1658 |
def bot(*args, retry=False):
    """
    Generator that streams responses for a single chatbot.

    :param args: inputs forwarded unchanged to prep_bot()
    :param retry: forwarded to prep_bot()
    :return: yields each item produced by get_response(); always clears the torch cache
             and the collection's embeddings when the stream ends or fails
    """
    chat_history, evaluate_fun, chosen_mode, scratch_dbs = prep_bot(*args, retry=retry)
    try:
        for partial_response in get_response(evaluate_fun, chat_history):
            yield partial_response
    finally:
        clear_torch_cache()
        clear_embeddings(chosen_mode, scratch_dbs)
|
1666 |
|
1667 |
def all_bot(*args, retry=False, model_states1=None):
|
1668 |
args_list = list(args).copy()
|
|
|
1672 |
stream_output1 = args_list[eval_func_param_names.index('stream_output')]
|
1673 |
max_time1 = args_list[eval_func_param_names.index('max_time')]
|
1674 |
langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
|
1675 |
+
isize = len(input_args_list) + 1 # states + chat history
|
1676 |
+
db1s = None
|
1677 |
try:
|
1678 |
gen_list = []
|
1679 |
for chatboti, (chatbot1, model_state1) in enumerate(zip(chatbots, model_states1)):
|
1680 |
args_list1 = args_list0.copy()
|
1681 |
+
args_list1.insert(-isize + 2,
|
1682 |
+
model_state1) # insert at -2 so is at -3, and after chatbot1 added, at -4
|
1683 |
# if at start, have None in response still, replace with '' so client etc. acts like normal
|
1684 |
# assumes other parts of code treat '' and None as if no response yet from bot
|
1685 |
# can't do this later in bot code as racy with threaded generators
|
|
|
1689 |
# so consistent with prep_bot()
|
1690 |
# with model_state1 at -3, my_db_state1 at -2, and history(chatbot) at -1
|
1691 |
# langchain_mode1 and my_db_state1 should be same for every bot
|
1692 |
+
history, fun1, langchain_mode1, db1s = prep_bot(*tuple(args_list1), retry=retry,
|
1693 |
+
which_model=chatboti)
|
1694 |
gen1 = get_response(fun1, history)
|
1695 |
if stream_output1:
|
1696 |
gen1 = TimeoutIterator(gen1, timeout=0.01, sentinel=None, raise_on_exception=False)
|
|
|
1736 |
print("Generate exceptions: %s" % exceptions, flush=True)
|
1737 |
finally:
|
1738 |
clear_torch_cache()
|
1739 |
+
clear_embeddings(langchain_mode1, db1s)
|
1740 |
|
1741 |
# NORMAL MODEL
|
1742 |
user_args = dict(fn=functools.partial(user, sanitize_user_prompt=kwargs['sanitize_user_prompt']),
|
|
|
1744 |
outputs=text_output,
|
1745 |
)
|
1746 |
bot_args = dict(fn=bot,
|
1747 |
+
inputs=inputs_list + [model_state, my_db_state, selection_docs_state] + [text_output],
|
1748 |
outputs=[text_output, chat_exception_text],
|
1749 |
)
|
1750 |
retry_bot_args = dict(fn=functools.partial(bot, retry=True),
|
1751 |
+
inputs=inputs_list + [model_state, my_db_state, selection_docs_state] + [text_output],
|
1752 |
outputs=[text_output, chat_exception_text],
|
1753 |
)
|
1754 |
retry_user_args = dict(fn=functools.partial(user, retry=True),
|
|
|
1766 |
outputs=text_output2,
|
1767 |
)
|
1768 |
bot_args2 = dict(fn=bot,
|
1769 |
+
inputs=inputs_list2 + [model_state2, my_db_state, selection_docs_state] + [text_output2],
|
1770 |
outputs=[text_output2, chat_exception_text],
|
1771 |
)
|
1772 |
retry_bot_args2 = dict(fn=functools.partial(bot, retry=True),
|
1773 |
+
inputs=inputs_list2 + [model_state2, my_db_state, selection_docs_state] + [text_output2],
|
1774 |
outputs=[text_output2, chat_exception_text],
|
1775 |
)
|
1776 |
retry_user_args2 = dict(fn=functools.partial(user, retry=True),
|
|
|
1791 |
outputs=text_outputs,
|
1792 |
)
|
1793 |
all_bot_args = dict(fn=functools.partial(all_bot, model_states1=model_states),
|
1794 |
+
inputs=inputs_list + [my_db_state, selection_docs_state] + text_outputs,
|
1795 |
outputs=text_outputs + [chat_exception_text],
|
1796 |
)
|
1797 |
all_retry_bot_args = dict(fn=functools.partial(all_bot, model_states1=model_states, retry=True),
|
1798 |
+
inputs=inputs_list + [my_db_state, selection_docs_state] + text_outputs,
|
1799 |
outputs=text_outputs + [chat_exception_text],
|
1800 |
)
|
1801 |
all_retry_user_args = dict(fn=functools.partial(all_user, retry=True,
|
|
|
1957 |
def get_short_chat(x, short_chats, short_len=20, words=4):
|
1958 |
if x and len(x[0]) == 2 and x[0][0] is not None:
|
1959 |
short_chat = ' '.join(x[0][0][:short_len].split(' ')[:words]).strip()
|
1960 |
+
if not short_chat:
|
1961 |
+
# e.g. summarization, try using answer
|
1962 |
+
short_chat = ' '.join(x[0][1][:short_len].split(' ')[:words]).strip()
|
1963 |
+
if not short_chat:
|
1964 |
+
short_chat = 'Unk'
|
1965 |
short_chat = dedup(short_chat, short_chats)
|
1966 |
else:
|
1967 |
short_chat = None
|
|
|
2029 |
already_exists = any([is_chat_same(chat_list, x) for x in old_chat_lists])
|
2030 |
if not already_exists:
|
2031 |
chat_state1[short_chat] = chat_list.copy()
|
2032 |
+
|
2033 |
+
# reverse so newest at top
|
2034 |
+
choices = list(chat_state1.keys()).copy()
|
2035 |
+
choices.reverse()
|
2036 |
+
|
2037 |
+
return chat_state1, gr.update(choices=choices, value=None)
|
|
|
|
|
2038 |
|
2039 |
def switch_chat(chat_key, chat_state1, num_model_lock=0):
|
2040 |
chosen_chat = chat_state1[chat_key]
|
|
|
2065 |
|
2066 |
remove_chat_event = remove_chat_btn.click(remove_chat,
|
2067 |
inputs=[radio_chats, chat_state], outputs=[radio_chats, chat_state],
|
2068 |
+
queue=False, api_name='remove_chat')
|
2069 |
|
2070 |
def get_chats1(chat_state1):
|
2071 |
base = 'chats'
|
|
|
2096 |
new_chats = json.loads(f.read())
|
2097 |
for chat1_k, chat1_v in new_chats.items():
|
2098 |
# ignore chat1_k, regenerate and de-dup to avoid loss
|
2099 |
+
chat_state1, _ = save_chat(chat1_v, chat_state1, chat_is_list=True)
|
2100 |
except BaseException as e:
|
2101 |
t, v, tb = sys.exc_info()
|
2102 |
ex = ''.join(traceback.format_exception(t, v, tb))
|
|
|
2122 |
.then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=False) \
|
2123 |
.then(clear_scores, outputs=[score_text, score_text2, score_text_nochat])
|
2124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
2125 |
clear_event = save_chat_btn.click(save_chat,
|
2126 |
inputs=[text_output, text_output2] + text_outputs + [chat_state],
|
2127 |
+
outputs=[chat_state, radio_chats],
|
2128 |
+
api_name='save_chat' if allow_api else None)
|
2129 |
+
if kwargs['score_model']:
|
2130 |
+
clear_event2 = clear_event.then(clear_scores, outputs=[score_text, score_text2, score_text_nochat])
|
|
|
2131 |
|
2132 |
# NOTE: clear of instruction/iinput for nochat has to come after score,
|
2133 |
# because score for nochat consumes actual textbox, while chat consumes chat history filled by user()
|
2134 |
no_chat_args = dict(fn=fun,
|
2135 |
+
inputs=[model_state, my_db_state, selection_docs_state] + inputs_list,
|
2136 |
outputs=text_output_nochat,
|
2137 |
queue=queue,
|
2138 |
)
|
|
|
2151 |
.then(clear_torch_cache)
|
2152 |
|
2153 |
submit_event_nochat_api = submit_nochat_api.click(fun_with_dict_str,
|
2154 |
+
inputs=[model_state, my_db_state, selection_docs_state,
|
2155 |
+
inputs_dict_str],
|
2156 |
outputs=text_output_nochat_api,
|
2157 |
queue=True, # required for generator
|
2158 |
api_name='submit_nochat_api' if allow_api else None) \
|
|
|
2402 |
print("Exception: %s" % str(e), flush=True)
|
2403 |
return json.dumps(sys_dict)
|
2404 |
|
2405 |
+
# copy so the launch command can be added without touching all_kwargs
system_kwargs = all_kwargs.copy()
system_kwargs.update(dict(command=str(' '.join(sys.argv))))
# bind system_kwargs, not all_kwargs: otherwise the 'command' entry added above
# never reaches get_system_info_dict and system_kwargs is dead code
get_system_info_dict_func = functools.partial(get_system_info_dict, **system_kwargs)
|
2408 |
|
2409 |
system_dict_event = system_btn2.click(get_system_info_dict_func,
|
|
|
2433 |
else:
|
2434 |
tokenizer = None
|
2435 |
if tokenizer is not None:
|
2436 |
+
langchain_mode1 = 'LLM'
|
2437 |
+
add_chat_history_to_context1 = True
|
2438 |
# fake user message to mimic bot()
|
2439 |
chat1 = copy.deepcopy(chat1)
|
2440 |
chat1 = chat1 + [['user_message1', None]]
|
2441 |
model_max_length1 = tokenizer.model_max_length
|
2442 |
+
context1 = history_to_context(chat1, langchain_mode1,
|
2443 |
+
add_chat_history_to_context1,
|
2444 |
+
prompt_type1, prompt_dict1, chat1,
|
2445 |
model_max_length1,
|
2446 |
memory_restriction_level1, keep_sources_in_context1)
|
2447 |
return str(tokenizer(context1, return_tensors="pt")['input_ids'].shape[1])
|
|
|
2471 |
,
|
2472 |
queue=False, api_name='stop' if allow_api else None).then(clear_torch_cache, queue=False)
|
2473 |
|
2474 |
+
demo.load(None, None, None, _js=get_dark_js() if kwargs['dark'] else None)
|
2475 |
|
2476 |
demo.queue(concurrency_count=kwargs['concurrency_count'], api_open=kwargs['api_open'])
|
2477 |
favicon_path = "h2o-logo.svg"
|
|
|
2486 |
# FIXME: disable for gptj, langchain or gpt4all modify print itself
|
2487 |
# FIXME: and any multi-threaded/async print will enter model output!
|
2488 |
scheduler.add_job(func=ping, trigger="interval", seconds=60)
|
2489 |
+
if is_public or os.getenv('PING_GPU'):
|
2490 |
+
scheduler.add_job(func=ping_gpu, trigger="interval", seconds=60 * 10)
|
2491 |
scheduler.start()
|
2492 |
|
2493 |
# import control
|
|
|
2506 |
demo.block_thread()
|
2507 |
|
2508 |
|
|
|
|
|
|
|
2509 |
def get_inputs_list(inputs_dict, model_lower, model_id=1):
|
2510 |
"""
|
2511 |
map gradio objects in locals() to inputs for evaluate().
|
|
|
2539 |
return inputs_list, inputs_dict_out
|
2540 |
|
2541 |
|
2542 |
+
def get_sources(db1s, langchain_mode, dbs=None, docs_state0=None):
|
2543 |
+
for k in db1s:
|
2544 |
+
set_userid(db1s[k])
|
2545 |
|
2546 |
if langchain_mode in ['ChatLLM', 'LLM']:
|
2547 |
source_files_added = "NA"
|
|
|
2550 |
source_files_added = "Not showing wiki_full, takes about 20 seconds and makes 4MB file." \
|
2551 |
" Ask jon.mckinney@h2o.ai for file if required."
|
2552 |
source_list = []
|
2553 |
+
elif langchain_mode in db1s and len(db1s[langchain_mode]) == 2 and db1s[langchain_mode][0] is not None:
|
2554 |
+
db1 = db1s[langchain_mode]
|
2555 |
from gpt_langchain import get_metadatas
|
2556 |
metadatas = get_metadatas(db1[0])
|
2557 |
source_list = sorted(set([x['source'] for x in metadatas]))
|
|
|
2582 |
db1[1] = str(uuid.uuid4())
|
2583 |
|
2584 |
|
2585 |
+
def update_user_db(file, db1s, selection_docs_state1, chunk, chunk_size, langchain_mode, dbs=None, **kwargs):
|
2586 |
+
kwargs.update(selection_docs_state1)
|
|
|
2587 |
if file is None:
|
2588 |
raise RuntimeError("Don't use change, use input")
|
2589 |
|
2590 |
try:
|
2591 |
+
return _update_user_db(file, db1s=db1s, chunk=chunk, chunk_size=chunk_size,
|
2592 |
langchain_mode=langchain_mode, dbs=dbs,
|
2593 |
**kwargs)
|
2594 |
except BaseException as e:
|
|
|
2619 |
user_id = db1[1]
|
2620 |
base_path = 'locks'
|
2621 |
makedirs(base_path)
|
2622 |
+
lock_file = os.path.join(base_path, "db_%s_%s.lock" % (langchain_mode.replace(' ', '_'), user_id))
|
2623 |
return lock_file
|
2624 |
|
2625 |
|
2626 |
def _update_user_db(file,
|
2627 |
+
db1s=None,
|
2628 |
chunk=None, chunk_size=None,
|
2629 |
+
dbs=None, db_type=None,
|
2630 |
+
langchain_mode='UserData',
|
2631 |
+
langchain_modes=None, # unused but required as part of selection_docs_state1
|
2632 |
+
langchain_mode_paths=None,
|
2633 |
+
visible_langchain_modes=None,
|
2634 |
use_openai_embedding=None,
|
2635 |
hf_embedding_model=None,
|
2636 |
caption_loader=None,
|
2637 |
enable_captions=None,
|
2638 |
captions_model=None,
|
2639 |
enable_ocr=None,
|
2640 |
+
enable_pdf_ocr=None,
|
2641 |
verbose=None,
|
2642 |
+
n_jobs=-1,
|
2643 |
is_url=None, is_txt=None,
|
2644 |
+
):
|
2645 |
+
assert db1s is not None
|
2646 |
assert chunk is not None
|
2647 |
assert chunk_size is not None
|
2648 |
assert use_openai_embedding is not None
|
|
|
2651 |
assert enable_captions is not None
|
2652 |
assert captions_model is not None
|
2653 |
assert enable_ocr is not None
|
2654 |
+
assert enable_pdf_ocr is not None
|
2655 |
assert verbose is not None
|
2656 |
|
|
|
|
|
2657 |
if dbs is None:
|
2658 |
dbs = {}
|
2659 |
assert isinstance(dbs, dict), "Wrong type for dbs: %s" % str(type(dbs))
|
|
|
2671 |
if langchain_mode == LangChainMode.DISABLED.value:
|
2672 |
return None, langchain_mode, get_source_files(), ""
|
2673 |
|
2674 |
+
if langchain_mode in [LangChainMode.LLM.value]:
|
2675 |
# then switch to MyData, so langchain_mode also becomes way to select where upload goes
|
2676 |
# but default to mydata if nothing chosen, since safest
|
2677 |
+
if LangChainMode.MY_DATA.value in visible_langchain_modes:
|
2678 |
+
langchain_mode = LangChainMode.MY_DATA.value
|
2679 |
+
|
2680 |
+
if langchain_mode_paths is None:
|
2681 |
+
langchain_mode_paths = {}
|
2682 |
+
user_path = langchain_mode_paths.get(langchain_mode)
|
2683 |
+
# UserData or custom, which has to be from user's disk
|
2684 |
+
if user_path is not None:
|
2685 |
# move temp files from gradio upload to stable location
|
2686 |
for fili, fil in enumerate(file):
|
2687 |
+
if isinstance(fil, str) and os.path.isfile(fil): # not url, text
|
2688 |
+
new_fil = os.path.normpath(os.path.join(user_path, os.path.basename(fil)))
|
2689 |
+
if os.path.normpath(os.path.abspath(fil)) != os.path.normpath(os.path.abspath(new_fil)):
|
2690 |
if os.path.isfile(new_fil):
|
2691 |
remove(new_fil)
|
2692 |
try:
|
|
|
2706 |
enable_captions=enable_captions,
|
2707 |
captions_model=captions_model,
|
2708 |
enable_ocr=enable_ocr,
|
2709 |
+
enable_pdf_ocr=enable_pdf_ocr,
|
2710 |
caption_loader=caption_loader,
|
2711 |
)
|
2712 |
exceptions = [x for x in sources if x.metadata.get('exception')]
|
2713 |
exceptions_strs = [x.metadata['exception'] for x in exceptions]
|
2714 |
sources = [x for x in sources if 'exception' not in x.metadata]
|
2715 |
|
2716 |
+
# below must at least come after langchain_mode is modified in case was LLM -> MyData,
|
2717 |
+
# so original langchain mode changed
|
2718 |
+
for k in db1s:
|
2719 |
+
set_userid(db1s[k])
|
2720 |
+
db1 = get_db1(db1s, langchain_mode)
|
2721 |
+
|
2722 |
+
lock_file = get_lock_file(db1s[LangChainMode.MY_DATA.value], langchain_mode) # user-level lock, not db-level lock
|
2723 |
with filelock.FileLock(lock_file):
|
2724 |
+
if langchain_mode in db1s:
|
2725 |
if db1[0] is not None:
|
2726 |
# then add
|
2727 |
db, num_new_sources, new_sources_metadata = add_to_db(db1[0], sources, db_type=db_type,
|
|
|
2731 |
# in testing expect:
|
2732 |
# assert len(db1) == 2 and db1[1] is None, "Bad MyData db: %s" % db1
|
2733 |
# for production hit, when user gets clicky:
|
2734 |
+
assert len(db1) == 2, "Bad %s db: %s" % (langchain_mode, db1)
|
2735 |
+
assert db1[1] is not None, "db hash was None, not allowed"
|
2736 |
# then create
|
2737 |
# if added has to original state and didn't change, then would be shared db for all users
|
2738 |
persist_directory = os.path.join(scratch_base_dir, 'db_dir_%s_%s' % (langchain_mode, db1[1]))
|
|
|
2754 |
use_openai_embedding=use_openai_embedding,
|
2755 |
hf_embedding_model=hf_embedding_model)
|
2756 |
else:
|
2757 |
+
# then create. Or might just be that dbs is unfilled, then it will fill, then add
|
2758 |
db = get_db(sources, use_openai_embedding=use_openai_embedding,
|
2759 |
db_type=db_type,
|
2760 |
persist_directory=persist_directory,
|
|
|
2768 |
return None, langchain_mode, source_files_added, '\n'.join(exceptions_strs)
|
2769 |
|
2770 |
|
2771 |
+
def get_db(db1s, langchain_mode, dbs=None):
|
2772 |
+
db1 = get_db1(db1s, langchain_mode)
|
2773 |
+
lock_file = get_lock_file(db1s[LangChainMode.MY_DATA.value], langchain_mode)
|
2774 |
|
2775 |
with filelock.FileLock(lock_file):
|
2776 |
if langchain_mode in ['wiki_full']:
|
2777 |
# NOTE: avoid showing full wiki. Takes about 30 seconds over about 90k entries, but not useful for now
|
2778 |
db = None
|
2779 |
+
elif langchain_mode in db1s and len(db1) == 2 and db1[0] is not None:
|
2780 |
db = db1[0]
|
2781 |
elif dbs is not None and langchain_mode in dbs and dbs[langchain_mode] is not None:
|
2782 |
db = dbs[langchain_mode]
|
|
|
2785 |
return db
|
2786 |
|
2787 |
|
2788 |
+
def get_source_files_given_langchain_mode(db1s, langchain_mode='UserData', dbs=None):
    """
    Return the source-file listing for the database behind a collection.

    :param db1s: dict of this session's databases keyed by collection name
    :param langchain_mode: collection name to look up
    :param dbs: shared databases keyed by collection name, or None
    :return: "Sources: N/A" for the 'ChatLLM'/'LLM' modes or when no database is found,
             else the result of get_source_files() for that database
    """
    # the lookup happens first, even for the LLM-only modes
    found_db = get_db(db1s, langchain_mode, dbs=dbs)
    is_llm_only = langchain_mode in ('ChatLLM', 'LLM')
    if is_llm_only or found_db is None:
        return "Sources: N/A"
    return get_source_files(db=found_db, exceptions=None)
|
|
|
2885 |
return source_files_added
|
2886 |
|
2887 |
|
2888 |
+
def update_and_get_source_files_given_langchain_mode(db1s, langchain_mode, chunk, chunk_size,
|
2889 |
+
dbs=None, first_para=None,
|
2890 |
+
text_limit=None,
|
2891 |
+
langchain_mode_paths=None, db_type=None, load_db_if_exists=None,
|
2892 |
n_jobs=None, verbose=None):
|
2893 |
+
has_path = {k: v for k, v in langchain_mode_paths.items() if v}
|
2894 |
+
if langchain_mode in [LangChainMode.LLM.value, LangChainMode.MY_DATA.value]:
|
2895 |
+
# then assume user really meant UserData, to avoid extra clicks in UI,
|
2896 |
+
# since others can't be on disk, except custom user modes, which they should then select to query it
|
2897 |
+
if LangChainMode.USER_DATA.value in has_path:
|
2898 |
+
langchain_mode = LangChainMode.USER_DATA.value
|
2899 |
+
|
2900 |
+
db = get_db(db1s, langchain_mode, dbs=dbs)
|
2901 |
|
2902 |
from gpt_langchain import make_db
|
2903 |
db, num_new_sources, new_sources_metadata = make_db(use_openai_embedding=False,
|
|
|
2906 |
chunk=chunk,
|
2907 |
chunk_size=chunk_size,
|
2908 |
langchain_mode=langchain_mode,
|
2909 |
+
langchain_mode_paths=langchain_mode_paths,
|
2910 |
db_type=db_type,
|
2911 |
load_db_if_exists=load_db_if_exists,
|
2912 |
db=db,
|
2913 |
n_jobs=n_jobs,
|
2914 |
verbose=verbose)
|
2915 |
+
# during refreshing, might have "created" new db since not in dbs[] yet, so insert back just in case
|
2916 |
+
# so even if persisted, not kept up-to-date with dbs memory
|
2917 |
+
if langchain_mode in db1s:
|
2918 |
+
db1s[langchain_mode][0] = db
|
2919 |
+
else:
|
2920 |
+
dbs[langchain_mode] = db
|
2921 |
+
|
2922 |
# return only new sources with text saying such
|
2923 |
return get_source_files(db=None, exceptions=None, metadatas=new_sources_metadata)
|
2924 |
+
|
2925 |
+
|
2926 |
+
def get_db1(db1s, langchain_mode1):
    """
    Look up the database entry for a collection.

    :param db1s: dict of databases keyed by collection name
    :param langchain_mode1: collection name to look up
    :return: the stored entry itself (not a copy) when present, else a fresh [None, None]
    """
    if langchain_mode1 not in db1s:
        # indicates to code that not scratch database
        return [None, None]
    return db1s[langchain_mode1]
|
gradio_utils/__init__.py
ADDED
File without changes
|
gradio_utils/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (134 Bytes). View file
|
|
gradio_utils/__pycache__/css.cpython-310.pyc
CHANGED
Binary files a/gradio_utils/__pycache__/css.cpython-310.pyc and b/gradio_utils/__pycache__/css.cpython-310.pyc differ
|
|
gradio_utils/css.py
CHANGED
@@ -53,4 +53,8 @@ def make_css_base() -> str:
|
|
53 |
margin-bottom: 2.5rem;
|
54 |
}
|
55 |
.chatsmall chatbot {font-size: 10px !important}
|
|
|
|
|
|
|
|
|
56 |
"""
|
|
|
53 |
margin-bottom: 2.5rem;
|
54 |
}
|
55 |
.chatsmall chatbot {font-size: 10px !important}
|
56 |
+
|
57 |
+
.gradio-container {
|
58 |
+
max-width: none !important;
|
59 |
+
}
|
60 |
"""
|
h2oai_pipeline.py
CHANGED
@@ -11,6 +11,7 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
|
|
11 |
def __init__(self, *args, debug=False, chat=False, stream_output=False,
|
12 |
sanitize_bot_response=False,
|
13 |
use_prompter=True, prompter=None,
|
|
|
14 |
prompt_type=None, prompt_dict=None,
|
15 |
max_input_tokens=2048 - 256, **kwargs):
|
16 |
"""
|
@@ -34,6 +35,8 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
|
|
34 |
self.prompt_type = prompt_type
|
35 |
self.prompt_dict = prompt_dict
|
36 |
self.prompter = prompter
|
|
|
|
|
37 |
if self.use_prompter:
|
38 |
if self.prompter is not None:
|
39 |
assert self.prompter.prompt_type is not None
|
@@ -113,7 +116,7 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
|
|
113 |
def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
|
114 |
prompt_text, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt_text, self.tokenizer)
|
115 |
|
116 |
-
data_point = dict(context=
|
117 |
if self.prompter is not None:
|
118 |
prompt_text = self.prompter.generate_prompt(data_point)
|
119 |
self.prompt_text = prompt_text
|
|
|
11 |
def __init__(self, *args, debug=False, chat=False, stream_output=False,
|
12 |
sanitize_bot_response=False,
|
13 |
use_prompter=True, prompter=None,
|
14 |
+
context='', iinput='',
|
15 |
prompt_type=None, prompt_dict=None,
|
16 |
max_input_tokens=2048 - 256, **kwargs):
|
17 |
"""
|
|
|
35 |
self.prompt_type = prompt_type
|
36 |
self.prompt_dict = prompt_dict
|
37 |
self.prompter = prompter
|
38 |
+
self.context = context
|
39 |
+
self.iinput = iinput
|
40 |
if self.use_prompter:
|
41 |
if self.prompter is not None:
|
42 |
assert self.prompter.prompt_type is not None
|
|
|
116 |
def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
|
117 |
prompt_text, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt_text, self.tokenizer)
|
118 |
|
119 |
+
data_point = dict(context=self.context, instruction=prompt_text, input=self.iinput)
|
120 |
if self.prompter is not None:
|
121 |
prompt_text = self.prompter.generate_prompt(data_point)
|
122 |
self.prompt_text = prompt_text
|
iterators/__pycache__/timeout_iterator.cpython-310.pyc
CHANGED
Binary files a/iterators/__pycache__/timeout_iterator.cpython-310.pyc and b/iterators/__pycache__/timeout_iterator.cpython-310.pyc differ
|
|
iterators/timeout_iterator.py
CHANGED
@@ -48,7 +48,7 @@ class TimeoutIterator:
|
|
48 |
def interrupt(self):
|
49 |
"""
|
50 |
interrupt and stop the underlying thread.
|
51 |
-
the thread
|
52 |
the underlying iterator yields a value after that.
|
53 |
"""
|
54 |
self._interrupt = True
|
|
|
48 |
def interrupt(self):
|
49 |
"""
|
50 |
interrupt and stop the underlying thread.
|
51 |
+
the thread actually dies only after interrupt has been set and
|
52 |
the underlying iterator yields a value after that.
|
53 |
"""
|
54 |
self._interrupt = True
|
prompter.py
CHANGED
@@ -77,6 +77,12 @@ prompt_type_to_model_name = {
|
|
77 |
"mptchat": ['mosaicml/mpt-7b-chat', 'mosaicml/mpt-30b-chat', 'TheBloke/mpt-30B-chat-GGML'],
|
78 |
"vicuna11": ['lmsys/vicuna-33b-v1.3'],
|
79 |
"falcon": ['tiiuae/falcon-40b-instruct', 'tiiuae/falcon-40b', 'tiiuae/falcon-7b-instruct', 'tiiuae/falcon-7b'],
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
# could be plain, but default is correct prompt_type for default TheBloke model ggml-wizardLM-7B.q4_2.bin
|
81 |
}
|
82 |
if os.getenv('OPENAI_API_KEY'):
|
@@ -596,6 +602,28 @@ ASSISTANT:
|
|
596 |
chat_turn_sep = chat_sep = '\n'
|
597 |
humanstr = PreInstruct
|
598 |
botstr = PreResponse
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
599 |
else:
|
600 |
raise RuntimeError("No such prompt_type=%s" % prompt_type)
|
601 |
|
|
|
77 |
"mptchat": ['mosaicml/mpt-7b-chat', 'mosaicml/mpt-30b-chat', 'TheBloke/mpt-30B-chat-GGML'],
|
78 |
"vicuna11": ['lmsys/vicuna-33b-v1.3'],
|
79 |
"falcon": ['tiiuae/falcon-40b-instruct', 'tiiuae/falcon-40b', 'tiiuae/falcon-7b-instruct', 'tiiuae/falcon-7b'],
|
80 |
+
"llama2": [
|
81 |
+
'meta-llama/Llama-2-7b-chat-hf',
|
82 |
+
'meta-llama/Llama-2-13b-chat-hf',
|
83 |
+
'meta-llama/Llama-2-34b-chat-hf',
|
84 |
+
'meta-llama/Llama-2-70b-chat-hf',
|
85 |
+
],
|
86 |
# could be plain, but default is correct prompt_type for default TheBloke model ggml-wizardLM-7B.q4_2.bin
|
87 |
}
|
88 |
if os.getenv('OPENAI_API_KEY'):
|
|
|
602 |
chat_turn_sep = chat_sep = '\n'
|
603 |
humanstr = PreInstruct
|
604 |
botstr = PreResponse
|
605 |
+
elif prompt_type in [PromptType.llama2.value, str(PromptType.llama2.value),
|
606 |
+
PromptType.llama2.name]:
|
607 |
+
PreInstruct = ""
|
608 |
+
llama2_sys = "<<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n"
|
609 |
+
prompt = "<s>[INST] "
|
610 |
+
enable_sys = False # too much safety, hurts accuracy
|
611 |
+
if not (chat and reduced):
|
612 |
+
if enable_sys:
|
613 |
+
promptA = promptB = prompt + llama2_sys
|
614 |
+
else:
|
615 |
+
promptA = promptB = prompt
|
616 |
+
else:
|
617 |
+
promptA = promptB = ''
|
618 |
+
PreInput = None
|
619 |
+
PreResponse = ""
|
620 |
+
terminate_response = ["[INST]", "</s>"]
|
621 |
+
chat_sep = ' [/INST]'
|
622 |
+
chat_turn_sep = ' </s><s>[INST] '
|
623 |
+
humanstr = PreInstruct
|
624 |
+
botstr = PreResponse
|
625 |
+
if making_context:
|
626 |
+
PreResponse += " "
|
627 |
else:
|
628 |
raise RuntimeError("No such prompt_type=%s" % prompt_type)
|
629 |
|
requirements.txt
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
# for generate (gradio server) and finetune
|
2 |
datasets==2.13.0
|
3 |
sentencepiece==0.1.99
|
4 |
-
gradio==3.
|
5 |
-
huggingface_hub==0.
|
6 |
appdirs==1.4.4
|
7 |
fire==0.5.0
|
8 |
docutils==0.20.1
|
@@ -19,7 +19,7 @@ matplotlib==3.7.1
|
|
19 |
loralib==0.1.1
|
20 |
bitsandbytes==0.39.0
|
21 |
accelerate==0.20.3
|
22 |
-
|
23 |
transformers==4.30.2
|
24 |
tokenizers==0.13.3
|
25 |
APScheduler==3.10.1
|
@@ -35,7 +35,7 @@ tensorboard==2.13.0
|
|
35 |
neptune==1.2.0
|
36 |
|
37 |
# for gradio client
|
38 |
-
gradio_client==0.2.
|
39 |
beautifulsoup4==4.12.2
|
40 |
markdown==3.4.3
|
41 |
|
@@ -64,8 +64,8 @@ tiktoken==0.4.0
|
|
64 |
# optional: for OpenAI endpoint or embeddings (requires key)
|
65 |
openai==0.27.8
|
66 |
# optional for chat with PDF
|
67 |
-
langchain==0.0.
|
68 |
-
pypdf==3.
|
69 |
# avoid textract, requires old six
|
70 |
#textract==1.6.5
|
71 |
|
@@ -78,10 +78,10 @@ chromadb==0.3.25
|
|
78 |
#pymilvus==2.2.8
|
79 |
|
80 |
# weak url support, if can't install opencv etc. If comment-in this one, then comment-out unstructured[local-inference]==0.6.6
|
81 |
-
# unstructured==0.
|
82 |
|
83 |
# strong support for images
|
84 |
-
# Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libreoffice
|
85 |
unstructured[local-inference]==0.7.4
|
86 |
#pdf2image==1.16.3
|
87 |
#pytesseract==0.3.10
|
@@ -104,10 +104,10 @@ tabulate==0.9.0
|
|
104 |
pip-licenses==4.3.0
|
105 |
|
106 |
# weaviate vector db
|
107 |
-
weaviate-client==3.
|
108 |
# optional for chat with PDF
|
109 |
-
langchain==0.0.
|
110 |
-
pypdf==3.
|
111 |
# avoid textract, requires old six
|
112 |
#textract==1.6.5
|
113 |
|
@@ -120,10 +120,10 @@ chromadb==0.3.25
|
|
120 |
#pymilvus==2.2.8
|
121 |
|
122 |
# weak url support, if can't install opencv etc. If comment-in this one, then comment-out unstructured[local-inference]==0.6.6
|
123 |
-
# unstructured==0.
|
124 |
|
125 |
# strong support for images
|
126 |
-
# Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libreoffice
|
127 |
unstructured[local-inference]==0.7.4
|
128 |
#pdf2image==1.16.3
|
129 |
#pytesseract==0.3.10
|
@@ -146,8 +146,8 @@ tabulate==0.9.0
|
|
146 |
pip-licenses==4.3.0
|
147 |
|
148 |
# weaviate vector db
|
149 |
-
weaviate-client==3.
|
150 |
faiss-gpu==1.7.2
|
151 |
-
arxiv==1.4.
|
152 |
-
pymupdf==1.22.
|
153 |
# extract-msg==0.41.1 # GPL3
|
|
|
1 |
# for generate (gradio server) and finetune
|
2 |
datasets==2.13.0
|
3 |
sentencepiece==0.1.99
|
4 |
+
gradio==3.37.0
|
5 |
+
huggingface_hub==0.16.4
|
6 |
appdirs==1.4.4
|
7 |
fire==0.5.0
|
8 |
docutils==0.20.1
|
|
|
19 |
loralib==0.1.1
|
20 |
bitsandbytes==0.39.0
|
21 |
accelerate==0.20.3
|
22 |
+
peft==0.4.0
|
23 |
transformers==4.30.2
|
24 |
tokenizers==0.13.3
|
25 |
APScheduler==3.10.1
|
|
|
35 |
neptune==1.2.0
|
36 |
|
37 |
# for gradio client
|
38 |
+
gradio_client==0.2.10
|
39 |
beautifulsoup4==4.12.2
|
40 |
markdown==3.4.3
|
41 |
|
|
|
64 |
# optional: for OpenAI endpoint or embeddings (requires key)
|
65 |
openai==0.27.8
|
66 |
# optional for chat with PDF
|
67 |
+
langchain==0.0.235
|
68 |
+
pypdf==3.12.2
|
69 |
# avoid textract, requires old six
|
70 |
#textract==1.6.5
|
71 |
|
|
|
78 |
#pymilvus==2.2.8
|
79 |
|
80 |
# weak url support, if can't install opencv etc. If you uncomment this one, then comment out unstructured[local-inference]==0.7.4
|
81 |
+
# unstructured==0.8.1
|
82 |
|
83 |
# strong support for images
|
84 |
+
# Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libtesseract-dev libreoffice
|
85 |
unstructured[local-inference]==0.7.4
|
86 |
#pdf2image==1.16.3
|
87 |
#pytesseract==0.3.10
|
|
|
104 |
pip-licenses==4.3.0
|
105 |
|
106 |
# weaviate vector db
|
107 |
+
weaviate-client==3.22.1
|
108 |
# optional for chat with PDF
|
109 |
+
langchain==0.0.235
|
110 |
+
pypdf==3.12.2
|
111 |
# avoid textract, requires old six
|
112 |
#textract==1.6.5
|
113 |
|
|
|
120 |
#pymilvus==2.2.8
|
121 |
|
122 |
# weak url support, if can't install opencv etc. If you uncomment this one, then comment out unstructured[local-inference]==0.7.4
|
123 |
+
# unstructured==0.8.1
|
124 |
|
125 |
# strong support for images
|
126 |
+
# Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libtesseract-dev libreoffice
|
127 |
unstructured[local-inference]==0.7.4
|
128 |
#pdf2image==1.16.3
|
129 |
#pytesseract==0.3.10
|
|
|
146 |
pip-licenses==4.3.0
|
147 |
|
148 |
# weaviate vector db
|
149 |
+
weaviate-client==3.22.1
|
150 |
faiss-gpu==1.7.2
|
151 |
+
arxiv==1.4.8
|
152 |
+
pymupdf==1.22.5 # AGPL license
|
153 |
# extract-msg==0.41.1 # GPL3
|
utils.py
CHANGED
@@ -5,6 +5,7 @@ import inspect
|
|
5 |
import os
|
6 |
import gc
|
7 |
import pathlib
|
|
|
8 |
import random
|
9 |
import shutil
|
10 |
import subprocess
|
@@ -111,12 +112,15 @@ def system_info():
|
|
111 |
system = {}
|
112 |
# https://stackoverflow.com/questions/48951136/plot-multiple-graphs-in-one-plot-using-tensorboard
|
113 |
# https://arshren.medium.com/monitoring-your-devices-in-python-5191d672f749
|
114 |
-
|
115 |
-
|
116 |
-
coretemp
|
117 |
-
|
118 |
-
|
119 |
-
|
|
|
|
|
|
|
120 |
|
121 |
# https://github.com/gpuopenanalytics/pynvml/blob/master/help_query_gpu.txt
|
122 |
try:
|
@@ -779,6 +783,9 @@ def _traced_func(func, *args, **kwargs):
|
|
779 |
|
780 |
|
781 |
def call_subprocess_onetask(func, args=None, kwargs=None):
|
|
|
|
|
|
|
782 |
if isinstance(args, list):
|
783 |
args = tuple(args)
|
784 |
if args is None:
|
@@ -1001,3 +1008,73 @@ def set_openai(inference_server):
|
|
1001 |
openai.api_base = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1")
|
1002 |
inf_type = inference_server
|
1003 |
return openai, inf_type
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
import os
|
6 |
import gc
|
7 |
import pathlib
|
8 |
+
import pickle
|
9 |
import random
|
10 |
import shutil
|
11 |
import subprocess
|
|
|
112 |
system = {}
|
113 |
# https://stackoverflow.com/questions/48951136/plot-multiple-graphs-in-one-plot-using-tensorboard
|
114 |
# https://arshren.medium.com/monitoring-your-devices-in-python-5191d672f749
|
115 |
+
try:
|
116 |
+
temps = psutil.sensors_temperatures(fahrenheit=False)
|
117 |
+
if 'coretemp' in temps:
|
118 |
+
coretemp = temps['coretemp']
|
119 |
+
temp_dict = {k.label: k.current for k in coretemp}
|
120 |
+
for k, v in temp_dict.items():
|
121 |
+
system['CPU_C/%s' % k] = v
|
122 |
+
except AttributeError:
|
123 |
+
pass
|
124 |
|
125 |
# https://github.com/gpuopenanalytics/pynvml/blob/master/help_query_gpu.txt
|
126 |
try:
|
|
|
783 |
|
784 |
|
785 |
def call_subprocess_onetask(func, args=None, kwargs=None):
|
786 |
+
import platform
|
787 |
+
if platform.system() in ['Darwin', 'Windows']:
|
788 |
+
return func(*args, **kwargs)
|
789 |
if isinstance(args, list):
|
790 |
args = tuple(args)
|
791 |
if args is None:
|
|
|
1008 |
openai.api_base = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1")
|
1009 |
inf_type = inference_server
|
1010 |
return openai, inf_type
|
1011 |
+
|
1012 |
+
|
1013 |
+
visible_langchain_modes_file = 'visible_langchain_modes.pkl'
|
1014 |
+
|
1015 |
+
|
1016 |
+
def save_collection_names(langchain_modes, visible_langchain_modes, langchain_mode_paths, LangChainMode, db1s):
    """
    Persist collection (langchain mode) names to pickle files.

    The modes are split into two groups, each dumped to its own file as a
    (modes, visible_modes, mode_paths) tuple:
      * user modes: names that are NOT keys of db1s, written to visible_langchain_modes_file
      * scratch modes: names that ARE keys of db1s, written to visible_langchain_modes_file + user hash

    extra controls if UserData type or MyData type

    :param langchain_modes: iterable of all mode names
    :param visible_langchain_modes: iterable of visible mode names
    :param langchain_mode_paths: dict of mode name -> path; entries for 'ChatLLM', 'LLM', 'Disabled' are dropped
    :param LangChainMode: enum class providing MY_DATA
    :param db1s: dict whose keys are the scratch collection names; element [1] of the MyData entry is the user hash
    :return: None
    """

    # use first default MyData hash as general user hash to maintain file
    # if user moves MyData from langchain modes, db will still survive, so can still use hash
    scratch_collection_names = set(db1s.keys())
    try:
        user_hash = db1s.get(LangChainMode.MY_DATA.value)[1]
    except (TypeError, IndexError):
        # no (or too short) MyData entry: previously this crashed with IndexError on ''[1]
        user_hash = ''

    llms = ('ChatLLM', 'LLM', 'Disabled')

    scratch_langchain_modes = [x for x in langchain_modes if x in scratch_collection_names]
    scratch_visible_langchain_modes = [x for x in visible_langchain_modes if x in scratch_collection_names]
    scratch_langchain_mode_paths = {k: v for k, v in langchain_mode_paths.items() if
                                    k in scratch_collection_names and k not in llms}

    user_langchain_modes = [x for x in langchain_modes if x not in scratch_collection_names]
    user_visible_langchain_modes = [x for x in visible_langchain_modes if x not in scratch_collection_names]
    user_langchain_mode_paths = {k: v for k, v in langchain_mode_paths.items() if
                                 k not in scratch_collection_names and k not in llms}

    base_path = 'locks'
    makedirs(base_path)

    # (file suffix, payload) pairs: user file first, then the per-user scratch file
    to_save = [('', (user_langchain_modes, user_visible_langchain_modes, user_langchain_mode_paths))]
    if user_hash != '':
        # with an empty suffix both names coincide and the scratch dump would overwrite the user file
        to_save.append(
            (user_hash, (scratch_langchain_modes, scratch_visible_langchain_modes, scratch_langchain_mode_paths)))

    for extra, payload in to_save:
        file = "%s%s" % (visible_langchain_modes_file, extra)
        with filelock.FileLock(os.path.join(base_path, "%s.lock" % file)):
            with open(file, 'wb') as f:
                pickle.dump(payload, f)
|
1054 |
+
|
1055 |
+
|
1056 |
+
def load_collection_enum(extra):
    """
    Load collection (langchain mode) names from the pickle file visible_langchain_modes_file + extra.

    extra controls if UserData type or MyData type

    Loading is best-effort: if the file is missing or cannot be read, the error is printed
    and empty defaults are returned.  Any string path found in the loaded mode paths that
    is not an existing directory is (re)created.

    :param extra: suffix appended to visible_langchain_modes_file to select which file to read
    :return: tuple (langchain_modes, visible_langchain_modes, langchain_mode_paths) as list, list, dict
    """
    file = "%s%s" % (visible_langchain_modes_file, extra)
    langchain_modes_from_file = []
    visible_langchain_modes_from_file = []
    langchain_mode_paths_from_file = {}
    # test the file that is actually opened below, not the un-suffixed base name
    if os.path.isfile(file):
        try:
            # same lock location as the writer (save_collection_names), so reads and writes exclude each other
            base_path = 'locks'
            makedirs(base_path)
            with filelock.FileLock(os.path.join(base_path, "%s.lock" % file)):
                with open(file, 'rb') as f:
                    # NOTE(review): pickle.load can execute arbitrary code; this file must only ever be
                    # written by this application, never supplied by an untrusted party
                    langchain_modes_from_file, visible_langchain_modes_from_file, langchain_mode_paths_from_file = \
                        pickle.load(f)
        except Exception as e:
            # Exception rather than BaseException so KeyboardInterrupt/SystemExit still propagate
            print("Cannot load %s, ignoring error: %s" % (file, str(e)), flush=True)
        for v in langchain_mode_paths_from_file.values():
            # type check first so os.path.isdir never sees a non-path value
            if isinstance(v, str) and not os.path.isdir(v):
                # assume was deleted, but need to make again to avoid extra code elsewhere
                makedirs(v)
    return langchain_modes_from_file, visible_langchain_modes_from_file, langchain_mode_paths_from_file
|
1077 |
+
|
1078 |
+
|
1079 |
+
def remove_collection_enum():
    """
    Remove the persisted collection-names file (the un-suffixed visible_langchain_modes_file).
    """
    target = visible_langchain_modes_file
    remove(target)
|