Spaces:
Running
Running
pseudotensor
committed on
Commit
•
df5eeb7
1
Parent(s):
935bf6f
Update with h2oGPT hash 23aaa9c9839867b3f0c86e7722cc7fbdae414fc4
Browse files- src/db_utils.py +54 -0
- src/gpt_langchain.py +2 -51
- src/gradio_runner.py +5 -2
src/db_utils.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import uuid
|
2 |
+
|
3 |
+
from enums import LangChainMode
|
4 |
+
|
5 |
+
|
6 |
+
def set_userid(db1s, requests_state1, get_userid_auth):
    """Fill in the user id and username of the MyData entry when they are unset.

    :param db1s: mapping indexed by langchain mode value; the MyData entry is
        a mutable sequence laid out as described in ``length_db1``
    :param requests_state1: request state; its ``'username'`` item is used
        when present
    :param get_userid_auth: callable invoked with ``requests_state1`` to obtain
        the user id when slot 1 is still falsy
    """
    my_data = db1s[LangChainMode.MY_DATA.value]
    assert my_data is not None and len(my_data) == length_db1()

    # slot 1: user id, only filled when nothing truthy is stored yet
    if not my_data[1]:
        my_data[1] = get_userid_auth(requests_state1)

    # slot 2: username, falls back to None when the request state has none
    if not my_data[2]:
        my_data[2] = requests_state1['username'] if 'username' in requests_state1 else None
|
16 |
+
|
17 |
+
|
18 |
+
def set_userid_direct(db1s, userid, username):
    """Unconditionally store ``userid`` and ``username`` in the MyData entry.

    :param db1s: mapping indexed by langchain mode value
    :param userid: value written to slot 1 of the MyData entry
    :param username: value written to slot 2 of the MyData entry
    """
    my_data = db1s[LangChainMode.MY_DATA.value]
    my_data[1], my_data[2] = userid, username
|
22 |
+
|
23 |
+
|
24 |
+
def get_userid_direct(db1s):
    """Return the user id held in slot 1 of the MyData entry.

    An empty string is returned when ``db1s`` is None.
    """
    if db1s is None:
        return ''
    return db1s[LangChainMode.MY_DATA.value][1]
|
26 |
+
|
27 |
+
|
28 |
+
def get_username_direct(db1s):
    """Return the username held in slot 2 of the MyData entry.

    An empty string is returned when ``db1s`` is None.
    """
    if db1s is None:
        return ''
    return db1s[LangChainMode.MY_DATA.value][2]
|
30 |
+
|
31 |
+
|
32 |
+
def get_dbid(db1):
    """Return the db id, i.e. the item stored at index 1 of ``db1``."""
    dbid = db1[1]
    return dbid
|
34 |
+
|
35 |
+
|
36 |
+
def set_dbid(db1):
    """Assign a fresh UUID string to slot 1 of ``db1`` if that slot is None.

    Can only be called after a function has been called, so that it applies to
    a specific user; not inside gr.State(), which occurs during app init.
    """
    assert db1 is not None and len(db1) == length_db1()
    if db1[1] is not None:
        # an id is already present, leave it untouched
        return
    # the uuid stored in the db entry doubles as the user ID
    db1[1] = str(uuid.uuid4())
|
42 |
+
|
43 |
+
|
44 |
+
def length_db1():
    """Return the number of slots in a per-mode db state list.

    Slot layout for MyData:
        0: db
        1: userid and dbid
        2: username

    Slot layout for others:
        0: db
        1: dbid
        2: None
    """
    slot_count = 3
    return slot_count
|
src/gpt_langchain.py
CHANGED
@@ -37,6 +37,8 @@ from langchain.tools import PythonREPLTool
|
|
37 |
from langchain.tools.json.tool import JsonSpec
|
38 |
from tqdm import tqdm
|
39 |
|
|
|
|
|
40 |
from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
|
41 |
get_device, ProgressParallel, remove, hash_file, clear_torch_cache, NullContext, get_hf_server, FakeTokenizer, \
|
42 |
have_libreoffice, have_arxiv, have_playwright, have_selenium, have_tesseract, have_doctr, have_pymupdf, set_openai, \
|
@@ -4655,57 +4657,6 @@ def get_sources_answer(query, docs, answer, scores, show_rank,
|
|
4655 |
return ret, extra
|
4656 |
|
4657 |
|
4658 |
-
def set_userid(db1s, requests_state1, get_userid_auth):
    """Fill in the user id and username of the MyData entry when they are unset.

    :param db1s: mapping indexed by langchain mode value; the MyData entry is
        a mutable sequence laid out as described in ``length_db1``
    :param requests_state1: request state; its ``'username'`` item is used
        when present
    :param get_userid_auth: callable invoked with ``requests_state1`` to obtain
        the user id when slot 1 is still falsy
    """
    my_data = db1s[LangChainMode.MY_DATA.value]
    assert my_data is not None and len(my_data) == length_db1()

    # slot 1: user id, only filled when nothing truthy is stored yet
    if not my_data[1]:
        my_data[1] = get_userid_auth(requests_state1)

    # slot 2: username, falls back to None when the request state has none
    if not my_data[2]:
        my_data[2] = requests_state1['username'] if 'username' in requests_state1 else None
|
4668 |
-
|
4669 |
-
|
4670 |
-
def set_userid_direct(db1s, userid, username):
    """Unconditionally store ``userid`` and ``username`` in the MyData entry.

    :param db1s: mapping indexed by langchain mode value
    :param userid: value written to slot 1 of the MyData entry
    :param username: value written to slot 2 of the MyData entry
    """
    my_data = db1s[LangChainMode.MY_DATA.value]
    my_data[1], my_data[2] = userid, username
|
4674 |
-
|
4675 |
-
|
4676 |
-
def get_userid_direct(db1s):
    """Return the user id held in slot 1 of the MyData entry.

    An empty string is returned when ``db1s`` is None.
    """
    if db1s is None:
        return ''
    return db1s[LangChainMode.MY_DATA.value][1]
|
4678 |
-
|
4679 |
-
|
4680 |
-
def get_username_direct(db1s):
    """Return the username held in slot 2 of the MyData entry.

    An empty string is returned when ``db1s`` is None.
    """
    if db1s is None:
        return ''
    return db1s[LangChainMode.MY_DATA.value][2]
|
4682 |
-
|
4683 |
-
|
4684 |
-
def get_dbid(db1):
    """Return the db id, i.e. the item stored at index 1 of ``db1``."""
    dbid = db1[1]
    return dbid
|
4686 |
-
|
4687 |
-
|
4688 |
-
def set_dbid(db1):
    """Assign a fresh UUID string to slot 1 of ``db1`` if that slot is None.

    Can only be called after a function has been called, so that it applies to
    a specific user; not inside gr.State(), which occurs during app init.
    """
    assert db1 is not None and len(db1) == length_db1()
    if db1[1] is not None:
        # an id is already present, leave it untouched
        return
    # the uuid stored in the db entry doubles as the user ID
    db1[1] = str(uuid.uuid4())
|
4694 |
-
|
4695 |
-
|
4696 |
-
def length_db1():
    """Return the number of slots in a per-mode db state list.

    Slot layout for MyData:
        0: db
        1: userid and dbid
        2: username

    Slot layout for others:
        0: db
        1: dbid
        2: None
    """
    slot_count = 3
    return slot_count
|
4707 |
-
|
4708 |
-
|
4709 |
def get_any_db(db1s, langchain_mode, langchain_mode_paths, langchain_mode_types,
|
4710 |
dbs=None,
|
4711 |
load_db_if_exists=None, db_type=None,
|
|
|
37 |
from langchain.tools.json.tool import JsonSpec
|
38 |
from tqdm import tqdm
|
39 |
|
40 |
+
from src.db_utils import length_db1, set_dbid, set_userid, get_dbid, get_userid_direct, get_username_direct, \
|
41 |
+
set_userid_direct
|
42 |
from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
|
43 |
get_device, ProgressParallel, remove, hash_file, clear_torch_cache, NullContext, get_hf_server, FakeTokenizer, \
|
44 |
have_libreoffice, have_arxiv, have_playwright, have_selenium, have_tesseract, have_doctr, have_pymupdf, set_openai, \
|
|
|
4657 |
return ret, extra
|
4658 |
|
4659 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4660 |
def get_any_db(db1s, langchain_mode, langchain_mode_paths, langchain_mode_types,
|
4661 |
dbs=None,
|
4662 |
load_db_if_exists=None, db_type=None,
|
src/gradio_runner.py
CHANGED
@@ -20,6 +20,7 @@ from iterators import TimeoutIterator
|
|
20 |
|
21 |
from gradio_utils.css import get_css
|
22 |
from gradio_utils.prompt_form import make_chatbots
|
|
|
23 |
|
24 |
# This is a hack to prevent Gradio from phoning home when it gets imported
|
25 |
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
|
@@ -459,7 +460,6 @@ def go_gradio(**kwargs):
|
|
459 |
if not requests_state1.get('host2', '') and hasattr(request, 'client') and hasattr(request.client, 'host'):
|
460 |
requests_state1.update(dict(host2=request.client.host))
|
461 |
if not requests_state1.get('username', '') and hasattr(request, 'username'):
|
462 |
-
from src.gpt_langchain import get_username_direct
|
463 |
# use already-defined username instead of keep changing to new uuid
|
464 |
# should be same as in requests_state1
|
465 |
db_username = get_username_direct(db1s)
|
@@ -469,7 +469,6 @@ def go_gradio(**kwargs):
|
|
469 |
|
470 |
def user_state_setup(db1s, requests_state1, request: gr.Request, *args):
|
471 |
requests_state1 = get_request_state(requests_state1, request, db1s)
|
472 |
-
from src.gpt_langchain import set_userid
|
473 |
set_userid(db1s, requests_state1, get_userid_auth)
|
474 |
args_list = [db1s, requests_state1] + list(args)
|
475 |
return tuple(args_list)
|
@@ -500,6 +499,8 @@ def go_gradio(**kwargs):
|
|
500 |
inference_server=kwargs['inference_server'],
|
501 |
prompt_type=kwargs['prompt_type'],
|
502 |
prompt_dict=kwargs['prompt_dict'],
|
|
|
|
|
503 |
)
|
504 |
)
|
505 |
|
@@ -3746,6 +3747,8 @@ def go_gradio(**kwargs):
|
|
3746 |
base_model=model_name, tokenizer_base_model=tokenizer_base_model,
|
3747 |
lora_weights=lora_weights, inference_server=server_name,
|
3748 |
prompt_type=prompt_type1, prompt_dict=prompt_dict1,
|
|
|
|
|
3749 |
)
|
3750 |
|
3751 |
max_max_new_tokens1 = get_max_max_new_tokens(model_state_new, **kwargs)
|
|
|
20 |
|
21 |
from gradio_utils.css import get_css
|
22 |
from gradio_utils.prompt_form import make_chatbots
|
23 |
+
from src.db_utils import set_userid, get_username_direct
|
24 |
|
25 |
# This is a hack to prevent Gradio from phoning home when it gets imported
|
26 |
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
|
|
|
460 |
if not requests_state1.get('host2', '') and hasattr(request, 'client') and hasattr(request.client, 'host'):
|
461 |
requests_state1.update(dict(host2=request.client.host))
|
462 |
if not requests_state1.get('username', '') and hasattr(request, 'username'):
|
|
|
463 |
# use already-defined username instead of keep changing to new uuid
|
464 |
# should be same as in requests_state1
|
465 |
db_username = get_username_direct(db1s)
|
|
|
469 |
|
470 |
def user_state_setup(db1s, requests_state1, request: gr.Request, *args):
|
471 |
requests_state1 = get_request_state(requests_state1, request, db1s)
|
|
|
472 |
set_userid(db1s, requests_state1, get_userid_auth)
|
473 |
args_list = [db1s, requests_state1] + list(args)
|
474 |
return tuple(args_list)
|
|
|
499 |
inference_server=kwargs['inference_server'],
|
500 |
prompt_type=kwargs['prompt_type'],
|
501 |
prompt_dict=kwargs['prompt_dict'],
|
502 |
+
visible_models=kwargs['visible_models'],
|
503 |
+
h2ogpt_key=kwargs['h2ogpt_key'],
|
504 |
)
|
505 |
)
|
506 |
|
|
|
3747 |
base_model=model_name, tokenizer_base_model=tokenizer_base_model,
|
3748 |
lora_weights=lora_weights, inference_server=server_name,
|
3749 |
prompt_type=prompt_type1, prompt_dict=prompt_dict1,
|
3750 |
+
# FIXME: not typically required, unless want to expose adding h2ogpt endpoint in UI
|
3751 |
+
visible_models=None, h2ogpt_key=None,
|
3752 |
)
|
3753 |
|
3754 |
max_max_new_tokens1 = get_max_max_new_tokens(model_state_new, **kwargs)
|