diff --git "a/gradio_runner.py" "b/gradio_runner.py"
deleted file mode 100644--- "a/gradio_runner.py"
+++ /dev/null
@@ -1,2933 +0,0 @@
-import ast
-import copy
-import functools
-import inspect
-import itertools
-import json
-import os
-import pprint
-import random
-import shutil
-import sys
-import time
-import traceback
-import typing
-import uuid
-import filelock
-import pandas as pd
-import requests
-import tabulate
-from iterators import TimeoutIterator
-
-from gradio_utils.css import get_css
-from gradio_utils.prompt_form import make_chatbots
-
-# This is a hack to prevent Gradio from phoning home when it gets imported
-os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
-
-
-def my_get(url, **kwargs):
- print('Gradio HTTP request redirected to localhost :)', flush=True)
- kwargs.setdefault('allow_redirects', True)
- return requests.api.request('get', 'http://127.0.0.1/', **kwargs)
-
-
-original_get = requests.get
-requests.get = my_get
-import gradio as gr
-
-requests.get = original_get
-
-
-def fix_pydantic_duplicate_validators_error():
- try:
- from pydantic import class_validators
-
- class_validators.in_ipython = lambda: True # type: ignore[attr-defined]
- except ImportError:
- pass
-
-
-fix_pydantic_duplicate_validators_error()
-
-from enums import DocumentSubset, no_model_str, no_lora_str, no_server_str, LangChainAction, LangChainMode, \
- DocumentChoice, langchain_modes_intrinsic
-from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, get_dark_js, spacing_xsm, radius_xsm, \
- text_xsm
-from prompter import prompt_type_to_model_name, prompt_types_strings, inv_prompt_type_to_model_lower, non_hf_types, \
- get_prompt
-from utils import flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
- ping, get_short_name, makedirs, get_kwargs, remove, system_info, ping_gpu, get_url, get_local_ip, \
- save_collection_names
-from gen import get_model, languages_covered, evaluate, score_qa, inputs_kwargs_list, scratch_base_dir, \
- get_max_max_new_tokens, get_minmax_top_k_docs, history_to_context, langchain_actions, langchain_agents_list, \
- update_langchain
-from evaluate_params import eval_func_param_names, no_default_param_names, eval_func_param_names_defaults, \
- input_args_list
-
-from apscheduler.schedulers.background import BackgroundScheduler
-
-
-def fix_text_for_gradio(text, fix_new_lines=False, fix_latex_dollars=True):
- if fix_latex_dollars:
- ts = text.split('```')
- for parti, part in enumerate(ts):
- inside = parti % 2 == 1
- if not inside:
- ts[parti] = ts[parti].replace('$', '﹩')
- text = '```'.join(ts)
-
- if fix_new_lines:
- # let Gradio handle code, since got improved recently
- ## FIXME: below conflicts with Gradio, but need to see if can handle multiple \n\n\n etc. properly as is.
- # ensure good visually, else markdown ignores multiple \n
- # handle code blocks
- ts = text.split('```')
- for parti, part in enumerate(ts):
- inside = parti % 2 == 1
- if not inside:
- ts[parti] = ts[parti].replace('\n', '
')
- text = '```'.join(ts)
- return text
-
-
-def go_gradio(**kwargs):
- allow_api = kwargs['allow_api']
- is_public = kwargs['is_public']
- is_hf = kwargs['is_hf']
- memory_restriction_level = kwargs['memory_restriction_level']
- n_gpus = kwargs['n_gpus']
- admin_pass = kwargs['admin_pass']
- model_states = kwargs['model_states']
- dbs = kwargs['dbs']
- db_type = kwargs['db_type']
- visible_langchain_actions = kwargs['visible_langchain_actions']
- visible_langchain_agents = kwargs['visible_langchain_agents']
- allow_upload_to_user_data = kwargs['allow_upload_to_user_data']
- allow_upload_to_my_data = kwargs['allow_upload_to_my_data']
- enable_sources_list = kwargs['enable_sources_list']
- enable_url_upload = kwargs['enable_url_upload']
- enable_text_upload = kwargs['enable_text_upload']
- use_openai_embedding = kwargs['use_openai_embedding']
- hf_embedding_model = kwargs['hf_embedding_model']
- enable_captions = kwargs['enable_captions']
- captions_model = kwargs['captions_model']
- enable_ocr = kwargs['enable_ocr']
- enable_pdf_ocr = kwargs['enable_pdf_ocr']
- caption_loader = kwargs['caption_loader']
-
- # for dynamic state per user session in gradio
- model_state0 = kwargs['model_state0']
- score_model_state0 = kwargs['score_model_state0']
- my_db_state0 = kwargs['my_db_state0']
- selection_docs_state0 = kwargs['selection_docs_state0']
- # for evaluate defaults
- langchain_modes0 = kwargs['langchain_modes']
- visible_langchain_modes0 = kwargs['visible_langchain_modes']
- langchain_mode_paths0 = kwargs['langchain_mode_paths']
-
- # easy update of kwargs needed for evaluate() etc.
- queue = True
- allow_upload = allow_upload_to_user_data or allow_upload_to_my_data
- kwargs.update(locals())
-
- # import control
- if kwargs['langchain_mode'] != 'Disabled':
- from gpt_langchain import file_types, have_arxiv
- else:
- have_arxiv = False
- file_types = []
-
- if 'mbart-' in kwargs['model_lower']:
- instruction_label_nochat = "Text to translate"
- else:
- instruction_label_nochat = "Instruction (Shift-Enter or push Submit to send message," \
- " use Enter for multiple input lines)"
-
- title = 'h2oGPT'
- description = """h2oGPT H2O LLM Studio
🤗 Models"""
- description_bottom = "If this host is busy, try
[Multi-Model](https://gpt.h2o.ai)
[Falcon 40B](https://falcon.h2o.ai)
[Vicuna 33B](https://wizardvicuna.h2o.ai)
[MPT 30B-Chat](https://mpt.h2o.ai)
[HF Spaces1](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot)
[HF Spaces2](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot2)
"
- if is_hf:
- description_bottom += ''''''
- task_info_md = ''
- css_code = get_css(kwargs)
-
- if kwargs['gradio_offline_level'] >= 0:
- # avoid GoogleFont that pulls from internet
- if kwargs['gradio_offline_level'] == 1:
- # front end would still have to download fonts or have cached it at some point
- base_font = 'Source Sans Pro'
- else:
- base_font = 'Helvetica'
- theme_kwargs = dict(font=(base_font, 'ui-sans-serif', 'system-ui', 'sans-serif'),
- font_mono=('IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'))
- else:
- theme_kwargs = dict()
- if kwargs['gradio_size'] == 'xsmall':
- theme_kwargs.update(dict(spacing_size=spacing_xsm, text_size=text_xsm, radius_size=radius_xsm))
- elif kwargs['gradio_size'] in [None, 'small']:
- theme_kwargs.update(dict(spacing_size=gr.themes.sizes.spacing_sm, text_size=gr.themes.sizes.text_sm,
- radius_size=gr.themes.sizes.spacing_sm))
- elif kwargs['gradio_size'] == 'large':
- theme_kwargs.update(dict(spacing_size=gr.themes.sizes.spacing_lg, text_size=gr.themes.sizes.text_lg),
- radius_size=gr.themes.sizes.spacing_lg)
- elif kwargs['gradio_size'] == 'medium':
- theme_kwargs.update(dict(spacing_size=gr.themes.sizes.spacing_md, text_size=gr.themes.sizes.text_md,
- radius_size=gr.themes.sizes.spacing_md))
-
- theme = H2oTheme(**theme_kwargs) if kwargs['h2ocolors'] else SoftTheme(**theme_kwargs)
- demo = gr.Blocks(theme=theme, css=css_code, title="h2oGPT", analytics_enabled=False)
- callback = gr.CSVLogger()
-
- model_options0 = flatten_list(list(prompt_type_to_model_name.values())) + kwargs['extra_model_options']
- if kwargs['base_model'].strip() not in model_options0:
- model_options0 = [kwargs['base_model'].strip()] + model_options0
- lora_options = kwargs['extra_lora_options']
- if kwargs['lora_weights'].strip() not in lora_options:
- lora_options = [kwargs['lora_weights'].strip()] + lora_options
- server_options = kwargs['extra_server_options']
- if kwargs['inference_server'].strip() not in server_options:
- server_options = [kwargs['inference_server'].strip()] + server_options
- if os.getenv('OPENAI_API_KEY'):
- if 'openai_chat' not in server_options:
- server_options += ['openai_chat']
- if 'openai' not in server_options:
- server_options += ['openai']
-
- # always add in no lora case
- # add fake space so doesn't go away in gradio dropdown
- model_options0 = [no_model_str] + model_options0
- lora_options = [no_lora_str] + lora_options
- server_options = [no_server_str] + server_options
- # always add in no model case so can free memory
- # add fake space so doesn't go away in gradio dropdown
-
- # transcribe, will be detranscribed before use by evaluate()
- if not kwargs['base_model'].strip():
- kwargs['base_model'] = no_model_str
-
- if not kwargs['lora_weights'].strip():
- kwargs['lora_weights'] = no_lora_str
-
- if not kwargs['inference_server'].strip():
- kwargs['inference_server'] = no_server_str
-
- # transcribe for gradio
- kwargs['gpu_id'] = str(kwargs['gpu_id'])
-
- no_model_msg = 'h2oGPT [ !!! Please Load Model in Models Tab !!! ]'
- output_label0 = f'h2oGPT [Model: {kwargs.get("base_model")}]' if kwargs.get(
- 'base_model') else no_model_msg
- output_label0_model2 = no_model_msg
-
- def update_prompt(prompt_type1, prompt_dict1, model_state1, which_model=0):
- if not prompt_type1 or which_model != 0:
- # keep prompt_type and prompt_dict in sync if possible
- prompt_type1 = kwargs.get('prompt_type', prompt_type1)
- prompt_dict1 = kwargs.get('prompt_dict', prompt_dict1)
- # prefer model specific prompt type instead of global one
- if not prompt_type1 or which_model != 0:
- prompt_type1 = model_state1.get('prompt_type', prompt_type1)
- prompt_dict1 = model_state1.get('prompt_dict', prompt_dict1)
-
- if not prompt_dict1 or which_model != 0:
- # if still not defined, try to get
- prompt_dict1 = kwargs.get('prompt_dict', prompt_dict1)
- if not prompt_dict1 or which_model != 0:
- prompt_dict1 = model_state1.get('prompt_dict', prompt_dict1)
- return prompt_type1, prompt_dict1
-
- default_kwargs = {k: kwargs[k] for k in eval_func_param_names_defaults}
- # ensure prompt_type consistent with prep_bot(), so nochat API works same way
- default_kwargs['prompt_type'], default_kwargs['prompt_dict'] = \
- update_prompt(default_kwargs['prompt_type'], default_kwargs['prompt_dict'],
- model_state1=model_state0, which_model=0)
- for k in no_default_param_names:
- default_kwargs[k] = ''
-
- def dummy_fun(x):
- # need dummy function to block new input from being sent until output is done,
- # else gets input_list at time of submit that is old, and shows up as truncated in chatbot
- return x
-
- def allow_empty_instruction(langchain_mode1, document_subset1, langchain_action1):
- allow = False
- allow |= langchain_action1 not in LangChainAction.QUERY.value
- allow |= document_subset1 in DocumentSubset.TopKSources.name
- if langchain_mode1 in [LangChainMode.LLM.value]:
- allow = False
- return allow
-
- with demo:
- # avoid actual model/tokenizer here or anything that would be bad to deepcopy
- # https://github.com/gradio-app/gradio/issues/3558
- model_state = gr.State(
- dict(model='model', tokenizer='tokenizer', device=kwargs['device'],
- base_model=kwargs['base_model'],
- tokenizer_base_model=kwargs['tokenizer_base_model'],
- lora_weights=kwargs['lora_weights'],
- inference_server=kwargs['inference_server'],
- prompt_type=kwargs['prompt_type'],
- prompt_dict=kwargs['prompt_dict'],
- )
- )
-
- def update_langchain_mode_paths(db1s, selection_docs_state1):
- if allow_upload_to_my_data:
- selection_docs_state1['langchain_mode_paths'].update({k: None for k in db1s})
- dup = selection_docs_state1['langchain_mode_paths'].copy()
- for k, v in dup.items():
- if k not in selection_docs_state1['visible_langchain_modes']:
- selection_docs_state1['langchain_mode_paths'].pop(k)
- return selection_docs_state1
-
- # Setup some gradio states for per-user dynamic state
- model_state2 = gr.State(kwargs['model_state_none'].copy())
- model_options_state = gr.State([model_options0])
- lora_options_state = gr.State([lora_options])
- server_options_state = gr.State([server_options])
- my_db_state = gr.State(my_db_state0)
- chat_state = gr.State({})
- docs_state00 = kwargs['document_choice'] + [DocumentChoice.ALL.value]
- docs_state0 = []
- [docs_state0.append(x) for x in docs_state00 if x not in docs_state0]
- docs_state = gr.State(docs_state0)
- viewable_docs_state0 = []
- viewable_docs_state = gr.State(viewable_docs_state0)
- selection_docs_state0 = update_langchain_mode_paths(my_db_state0, selection_docs_state0)
- selection_docs_state = gr.State(selection_docs_state0)
-
- gr.Markdown(f"""
- {get_h2o_title(title, description) if kwargs['h2ocolors'] else get_simple_title(title, description)}
- """)
-
- # go button visible if
- base_wanted = kwargs['base_model'] != no_model_str and kwargs['login_mode_if_model0']
- go_btn = gr.Button(value="ENTER", visible=base_wanted, variant="primary")
-
- nas = ' '.join(['NA'] * len(kwargs['model_states']))
- res_value = "Response Score: NA" if not kwargs[
- 'model_lock'] else "Response Scores: %s" % nas
-
- if kwargs['langchain_mode'] != LangChainMode.DISABLED.value:
- extra_prompt_form = ". For summarization, no query required, just click submit"
- else:
- extra_prompt_form = ""
- if kwargs['input_lines'] > 1:
- instruction_label = "Shift-Enter to Submit, Enter for more lines%s" % extra_prompt_form
- else:
- instruction_label = "Enter to Submit, Shift-Enter for more lines%s" % extra_prompt_form
-
- def get_langchain_choices(selection_docs_state1):
- langchain_modes = selection_docs_state1['langchain_modes']
- visible_langchain_modes = selection_docs_state1['visible_langchain_modes']
-
- if is_hf:
- # don't show 'wiki' since only usually useful for internal testing at moment
- no_show_modes = ['Disabled', 'wiki']
- else:
- no_show_modes = ['Disabled']
- allowed_modes = visible_langchain_modes.copy()
- # allowed_modes = [x for x in allowed_modes if x in dbs]
- allowed_modes += ['LLM']
- if allow_upload_to_my_data and 'MyData' not in allowed_modes:
- allowed_modes += ['MyData']
- if allow_upload_to_user_data and 'UserData' not in allowed_modes:
- allowed_modes += ['UserData']
- choices = [x for x in langchain_modes if x in allowed_modes and x not in no_show_modes]
- return choices
-
- def get_df_langchain_mode_paths(selection_docs_state1):
- langchain_mode_paths = selection_docs_state1['langchain_mode_paths']
- if langchain_mode_paths:
- df = pd.DataFrame.from_dict(langchain_mode_paths.items(), orient='columns')
- df.columns = ['Collection', 'Path']
- else:
- df = pd.DataFrame(None)
- return df
-
- normal_block = gr.Row(visible=not base_wanted, equal_height=False)
- with normal_block:
- side_bar = gr.Column(elem_id="col_container", scale=1, min_width=100)
- with side_bar:
- with gr.Accordion("Chats", open=False, visible=True):
- radio_chats = gr.Radio(value=None, label="Saved Chats", show_label=False,
- visible=True, interactive=True,
- type='value')
- upload_visible = kwargs['langchain_mode'] != 'Disabled' and allow_upload
- with gr.Accordion("Upload", open=False, visible=upload_visible):
- with gr.Column():
- with gr.Row(equal_height=False):
- file_types_str = '[' + ' '.join(file_types) + ' URL ArXiv TEXT' + ']'
- fileup_output = gr.File(label=f'Upload {file_types_str}',
- show_label=False,
- file_types=file_types,
- file_count="multiple",
- scale=1,
- min_width=0,
- elem_id="warning", elem_classes="feedback")
- fileup_output_text = gr.Textbox(visible=False)
- url_visible = kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_url_upload
- url_label = 'URL/ArXiv' if have_arxiv else 'URL'
- url_text = gr.Textbox(label=url_label,
- # placeholder="Enter Submits",
- max_lines=1,
- interactive=True)
- text_visible = kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_text_upload
- user_text_text = gr.Textbox(label='Paste Text',
- # placeholder="Enter Submits",
- interactive=True,
- visible=text_visible)
- github_textbox = gr.Textbox(label="Github URL", visible=False) # FIXME WIP
- database_visible = kwargs['langchain_mode'] != 'Disabled'
- with gr.Accordion("Resources", open=False, visible=database_visible):
- langchain_choices0 = get_langchain_choices(selection_docs_state0)
- langchain_mode = gr.Radio(
- langchain_choices0,
- value=kwargs['langchain_mode'],
- label="Collections",
- show_label=True,
- visible=kwargs['langchain_mode'] != 'Disabled',
- min_width=100)
- add_chat_history_to_context = gr.Checkbox(label="Chat History",
- value=kwargs['add_chat_history_to_context'])
- document_subset = gr.Radio([x.name for x in DocumentSubset],
- label="Subset",
- value=DocumentSubset.Relevant.name,
- interactive=True,
- )
- allowed_actions = [x for x in langchain_actions if x in visible_langchain_actions]
- langchain_action = gr.Radio(
- allowed_actions,
- value=allowed_actions[0] if len(allowed_actions) > 0 else None,
- label="Action",
- visible=True)
- allowed_agents = [x for x in langchain_agents_list if x in visible_langchain_agents]
- langchain_agents = gr.Dropdown(
- langchain_agents_list,
- value=kwargs['langchain_agents'],
- label="Agents",
- multiselect=True,
- interactive=True,
- visible=False) # WIP
- col_tabs = gr.Column(elem_id="col_container", scale=10)
- with (col_tabs, gr.Tabs()):
- with gr.TabItem("Chat"):
- if kwargs['langchain_mode'] == 'Disabled':
- text_output_nochat = gr.Textbox(lines=5, label=output_label0, show_copy_button=True,
- visible=not kwargs['chat'])
- else:
- # text looks a bit worse, but HTML links work
- text_output_nochat = gr.HTML(label=output_label0, visible=not kwargs['chat'])
- with gr.Row():
- # NOCHAT
- instruction_nochat = gr.Textbox(
- lines=kwargs['input_lines'],
- label=instruction_label_nochat,
- placeholder=kwargs['placeholder_instruction'],
- visible=not kwargs['chat'],
- )
- iinput_nochat = gr.Textbox(lines=4, label="Input context for Instruction",
- placeholder=kwargs['placeholder_input'],
- visible=not kwargs['chat'])
- submit_nochat = gr.Button("Submit", size='sm', visible=not kwargs['chat'])
- flag_btn_nochat = gr.Button("Flag", size='sm', visible=not kwargs['chat'])
- score_text_nochat = gr.Textbox("Response Score: NA", show_label=False,
- visible=not kwargs['chat'])
- submit_nochat_api = gr.Button("Submit nochat API", visible=False)
- inputs_dict_str = gr.Textbox(label='API input for nochat', show_label=False, visible=False)
- text_output_nochat_api = gr.Textbox(lines=5, label='API nochat output', visible=False,
- show_copy_button=True)
-
- # CHAT
- col_chat = gr.Column(visible=kwargs['chat'])
- with col_chat:
- with gr.Row(): # elem_id='prompt-form-area'):
- with gr.Column(scale=50):
- instruction = gr.Textbox(
- lines=kwargs['input_lines'],
- label='Ask anything',
- placeholder=instruction_label,
- info=None,
- elem_id='prompt-form',
- container=True,
- )
- submit_buttons = gr.Row(equal_height=False)
- with submit_buttons:
- mw1 = 50
- mw2 = 50
- with gr.Column(min_width=mw1):
- submit = gr.Button(value='Submit', variant='primary', size='sm',
- min_width=mw1)
- stop_btn = gr.Button(value="Stop", variant='secondary', size='sm',
- min_width=mw1)
- save_chat_btn = gr.Button("Save", size='sm', min_width=mw1)
- with gr.Column(min_width=mw2):
- retry_btn = gr.Button("Redo", size='sm', min_width=mw2)
- undo = gr.Button("Undo", size='sm', min_width=mw2)
- clear_chat_btn = gr.Button(value="Clear", size='sm', min_width=mw2)
- text_output, text_output2, text_outputs = make_chatbots(output_label0, output_label0_model2,
- **kwargs)
-
- with gr.Row():
- with gr.Column(visible=kwargs['score_model']):
- score_text = gr.Textbox(res_value,
- show_label=False,
- visible=True)
- score_text2 = gr.Textbox("Response Score2: NA", show_label=False,
- visible=False and not kwargs['model_lock'])
-
- with gr.TabItem("Document Selection"):
- document_choice = gr.Dropdown(docs_state0,
- label="Select Subset of Document(s) %s" % file_types_str,
- value=[DocumentChoice.ALL.value],
- interactive=True,
- multiselect=True,
- visible=kwargs['langchain_mode'] != 'Disabled',
- )
- sources_visible = kwargs['langchain_mode'] != 'Disabled' and enable_sources_list
- with gr.Row():
- with gr.Column(scale=1):
- get_sources_btn = gr.Button(value="Update UI with Document(s) from DB", scale=0, size='sm',
- visible=sources_visible)
- show_sources_btn = gr.Button(value="Show Sources from DB", scale=0, size='sm',
- visible=sources_visible)
- refresh_sources_btn = gr.Button(value="Update DB with new/changed files on disk", scale=0,
- size='sm',
- visible=sources_visible and allow_upload_to_user_data)
- with gr.Column(scale=4):
- pass
- with gr.Row():
- with gr.Column(scale=1):
- visible_add_remove_collection = (allow_upload_to_user_data or
- allow_upload_to_my_data) and \
- kwargs['langchain_mode'] != 'Disabled'
- add_placeholder = "e.g. UserData2, user_path2 (optional)" \
- if not is_public else "e.g. MyData2"
- remove_placeholder = "e.g. UserData2" if not is_public else "e.g. MyData2"
- new_langchain_mode_text = gr.Textbox(value="", visible=visible_add_remove_collection,
- label='Add Collection',
- placeholder=add_placeholder,
- interactive=True)
- remove_langchain_mode_text = gr.Textbox(value="", visible=visible_add_remove_collection,
- label='Remove Collection',
- placeholder=remove_placeholder,
- interactive=True)
- load_langchain = gr.Button(value="Load LangChain State", scale=0, size='sm',
- visible=allow_upload_to_user_data and
- kwargs['langchain_mode'] != 'Disabled')
- with gr.Column(scale=1):
- df0 = get_df_langchain_mode_paths(selection_docs_state0)
- langchain_mode_path_text = gr.Dataframe(value=df0,
- visible=visible_add_remove_collection,
- label='LangChain Mode-Path',
- show_label=False,
- interactive=False)
- with gr.Column(scale=4):
- pass
-
- sources_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list,
- equal_height=False)
- with sources_row:
- with gr.Column(scale=1):
- file_source = gr.File(interactive=False,
- label="Download File w/Sources")
- with gr.Column(scale=2):
- sources_text = gr.HTML(label='Sources Added', interactive=False)
-
- doc_exception_text = gr.Textbox(value="", label='Document Exceptions',
- interactive=False,
- visible=kwargs['langchain_mode'] != 'Disabled')
- with gr.TabItem("Document Viewer"):
- with gr.Row(visible=kwargs['langchain_mode'] != 'Disabled'):
- with gr.Column(scale=2):
- get_viewable_sources_btn = gr.Button(value="Update UI with Document(s) from DB", scale=0,
- size='sm',
- visible=sources_visible)
- view_document_choice = gr.Dropdown(viewable_docs_state0,
- label="Select Single Document",
- value=None,
- interactive=True,
- multiselect=False,
- visible=True,
- )
- with gr.Column(scale=4):
- pass
- document = 'http://infolab.stanford.edu/pub/papers/google.pdf'
- doc_view = gr.HTML(visible=False)
- doc_view2 = gr.Dataframe(visible=False)
- doc_view3 = gr.JSON(visible=False)
- doc_view4 = gr.Markdown(visible=False)
-
- with gr.TabItem("Chat History"):
- with gr.Row():
- with gr.Column(scale=1):
- remove_chat_btn = gr.Button(value="Remove Selected Saved Chats", visible=True, size='sm')
- flag_btn = gr.Button("Flag Current Chat", size='sm')
- export_chats_btn = gr.Button(value="Export Chats to Download", size='sm')
- with gr.Column(scale=4):
- pass
- with gr.Row():
- chats_file = gr.File(interactive=False, label="Download Exported Chats")
- chatsup_output = gr.File(label="Upload Chat File(s)",
- file_types=['.json'],
- file_count='multiple',
- elem_id="warning", elem_classes="feedback")
- with gr.Row():
- if 'mbart-' in kwargs['model_lower']:
- src_lang = gr.Dropdown(list(languages_covered().keys()),
- value=kwargs['src_lang'],
- label="Input Language")
- tgt_lang = gr.Dropdown(list(languages_covered().keys()),
- value=kwargs['tgt_lang'],
- label="Output Language")
-
- chat_exception_text = gr.Textbox(value="", visible=True, label='Chat Exceptions',
- interactive=False)
- with gr.TabItem("Expert"):
- with gr.Row():
- with gr.Column():
- stream_output = gr.components.Checkbox(label="Stream output",
- value=kwargs['stream_output'])
- prompt_type = gr.Dropdown(prompt_types_strings,
- value=kwargs['prompt_type'], label="Prompt Type",
- visible=not kwargs['model_lock'],
- interactive=not is_public,
- )
- prompt_type2 = gr.Dropdown(prompt_types_strings,
- value=kwargs['prompt_type'], label="Prompt Type Model 2",
- visible=False and not kwargs['model_lock'],
- interactive=not is_public)
- do_sample = gr.Checkbox(label="Sample",
- info="Enable sampler, required for use of temperature, top_p, top_k",
- value=kwargs['do_sample'])
- temperature = gr.Slider(minimum=0.01, maximum=2,
- value=kwargs['temperature'],
- label="Temperature",
- info="Lower is deterministic (but may lead to repeats), Higher more creative (but may lead to hallucinations)")
- top_p = gr.Slider(minimum=1e-3, maximum=1.0 - 1e-3,
- value=kwargs['top_p'], label="Top p",
- info="Cumulative probability of tokens to sample from")
- top_k = gr.Slider(
- minimum=1, maximum=100, step=1,
- value=kwargs['top_k'], label="Top k",
- info='Num. tokens to sample from'
- )
- # FIXME: https://github.com/h2oai/h2ogpt/issues/106
- if os.getenv('TESTINGFAIL'):
- max_beams = 8 if not (memory_restriction_level or is_public) else 1
- else:
- max_beams = 1
- num_beams = gr.Slider(minimum=1, maximum=max_beams, step=1,
- value=min(max_beams, kwargs['num_beams']), label="Beams",
- info="Number of searches for optimal overall probability. "
- "Uses more GPU memory/compute",
- interactive=False)
- max_max_new_tokens = get_max_max_new_tokens(model_state0, **kwargs)
- max_new_tokens = gr.Slider(
- minimum=1, maximum=max_max_new_tokens, step=1,
- value=min(max_max_new_tokens, kwargs['max_new_tokens']), label="Max output length",
- )
- min_new_tokens = gr.Slider(
- minimum=0, maximum=max_max_new_tokens, step=1,
- value=min(max_max_new_tokens, kwargs['min_new_tokens']), label="Min output length",
- )
- max_new_tokens2 = gr.Slider(
- minimum=1, maximum=max_max_new_tokens, step=1,
- value=min(max_max_new_tokens, kwargs['max_new_tokens']), label="Max output length 2",
- visible=False and not kwargs['model_lock'],
- )
- min_new_tokens2 = gr.Slider(
- minimum=0, maximum=max_max_new_tokens, step=1,
- value=min(max_max_new_tokens, kwargs['min_new_tokens']), label="Min output length 2",
- visible=False and not kwargs['model_lock'],
- )
- early_stopping = gr.Checkbox(label="EarlyStopping", info="Stop early in beam search",
- value=kwargs['early_stopping'])
- max_time = gr.Slider(minimum=0, maximum=kwargs['max_max_time'], step=1,
- value=min(kwargs['max_max_time'],
- kwargs['max_time']), label="Max. time",
- info="Max. time to search optimal output.")
- repetition_penalty = gr.Slider(minimum=0.01, maximum=3.0,
- value=kwargs['repetition_penalty'],
- label="Repetition Penalty")
- num_return_sequences = gr.Slider(minimum=1, maximum=10, step=1,
- value=kwargs['num_return_sequences'],
- label="Number Returns", info="Must be <= num_beams",
- interactive=not is_public)
- iinput = gr.Textbox(lines=4, label="Input",
- placeholder=kwargs['placeholder_input'],
- interactive=not is_public)
- context = gr.Textbox(lines=3, label="System Pre-Context",
- info="Directly pre-appended without prompt processing",
- interactive=not is_public)
- chat = gr.components.Checkbox(label="Chat mode", value=kwargs['chat'],
- visible=False, # no longer support nochat in UI
- interactive=not is_public,
- )
- count_chat_tokens_btn = gr.Button(value="Count Chat Tokens",
- visible=not is_public and not kwargs['model_lock'],
- interactive=not is_public)
- chat_token_count = gr.Textbox(label="Chat Token Count", value=None,
- visible=not is_public and not kwargs['model_lock'],
- interactive=False)
- chunk = gr.components.Checkbox(value=kwargs['chunk'],
- label="Whether to chunk documents",
- info="For LangChain",
- visible=kwargs['langchain_mode'] != 'Disabled',
- interactive=not is_public)
- min_top_k_docs, max_top_k_docs, label_top_k_docs = get_minmax_top_k_docs(is_public)
- top_k_docs = gr.Slider(minimum=min_top_k_docs, maximum=max_top_k_docs, step=1,
- value=kwargs['top_k_docs'],
- label=label_top_k_docs,
- info="For LangChain",
- visible=kwargs['langchain_mode'] != 'Disabled',
- interactive=not is_public)
- chunk_size = gr.Number(value=kwargs['chunk_size'],
- label="Chunk size for document chunking",
- info="For LangChain (ignored if chunk=False)",
- minimum=128,
- maximum=2048,
- visible=kwargs['langchain_mode'] != 'Disabled',
- interactive=not is_public,
- precision=0)
-
- with gr.TabItem("Models"):
- model_lock_msg = gr.Textbox(lines=1, label="Model Lock Notice",
- placeholder="Started in model_lock mode, no model changes allowed.",
- visible=bool(kwargs['model_lock']), interactive=False)
- load_msg = "Load-Unload Model/LORA [unload works if did not use --base_model]" if not is_public \
- else "LOAD-UNLOAD DISABLED FOR HOSTED DEMO"
- load_msg2 = "Load-Unload Model/LORA 2 [unload works if did not use --base_model]" if not is_public \
- else "LOAD-UNLOAD DISABLED FOR HOSTED DEMO 2"
- variant_load_msg = 'primary' if not is_public else 'secondary'
- compare_checkbox = gr.components.Checkbox(label="Compare Mode",
- value=kwargs['model_lock'],
- visible=not is_public and not kwargs['model_lock'])
- with gr.Row():
- n_gpus_list = [str(x) for x in list(range(-1, n_gpus))]
- with gr.Column():
- with gr.Row():
- with gr.Column(scale=20, visible=not kwargs['model_lock']):
- model_choice = gr.Dropdown(model_options_state.value[0], label="Choose Model",
- value=kwargs['base_model'])
- lora_choice = gr.Dropdown(lora_options_state.value[0], label="Choose LORA",
- value=kwargs['lora_weights'], visible=kwargs['show_lora'])
- server_choice = gr.Dropdown(server_options_state.value[0], label="Choose Server",
- value=kwargs['inference_server'], visible=not is_public)
- with gr.Column(scale=1, visible=not kwargs['model_lock']):
- load_model_button = gr.Button(load_msg, variant=variant_load_msg, scale=0,
- size='sm', interactive=not is_public)
- model_load8bit_checkbox = gr.components.Checkbox(
- label="Load 8-bit [requires support]",
- value=kwargs['load_8bit'], interactive=not is_public)
- model_use_gpu_id_checkbox = gr.components.Checkbox(
- label="Choose Devices [If not Checked, use all GPUs]",
- value=kwargs['use_gpu_id'], interactive=not is_public)
- model_gpu = gr.Dropdown(n_gpus_list,
- label="GPU ID [-1 = all GPUs, if Choose is enabled]",
- value=kwargs['gpu_id'], interactive=not is_public)
- model_used = gr.Textbox(label="Current Model", value=kwargs['base_model'],
- interactive=False)
- lora_used = gr.Textbox(label="Current LORA", value=kwargs['lora_weights'],
- visible=kwargs['show_lora'], interactive=False)
- server_used = gr.Textbox(label="Current Server",
- value=kwargs['inference_server'],
- visible=bool(kwargs['inference_server']) and not is_public,
- interactive=False)
- prompt_dict = gr.Textbox(label="Prompt (or Custom)",
- value=pprint.pformat(kwargs['prompt_dict'], indent=4),
- interactive=not is_public, lines=4)
- col_model2 = gr.Column(visible=False)
- with col_model2:
- with gr.Row():
- with gr.Column(scale=20, visible=not kwargs['model_lock']):
- model_choice2 = gr.Dropdown(model_options_state.value[0], label="Choose Model 2",
- value=no_model_str)
- lora_choice2 = gr.Dropdown(lora_options_state.value[0], label="Choose LORA 2",
- value=no_lora_str,
- visible=kwargs['show_lora'])
- server_choice2 = gr.Dropdown(server_options_state.value[0], label="Choose Server 2",
- value=no_server_str,
- visible=not is_public)
- with gr.Column(scale=1, visible=not kwargs['model_lock']):
- load_model_button2 = gr.Button(load_msg2, variant=variant_load_msg, scale=0,
- size='sm', interactive=not is_public)
- model_load8bit_checkbox2 = gr.components.Checkbox(
- label="Load 8-bit 2 [requires support]",
- value=kwargs['load_8bit'], interactive=not is_public)
- model_use_gpu_id_checkbox2 = gr.components.Checkbox(
- label="Choose Devices 2 [If not Checked, use all GPUs]",
- value=kwargs[
- 'use_gpu_id'], interactive=not is_public)
- model_gpu2 = gr.Dropdown(n_gpus_list,
- label="GPU ID 2 [-1 = all GPUs, if choose is enabled]",
- value=kwargs['gpu_id'], interactive=not is_public)
- # no model/lora loaded ever in model2 by default
- model_used2 = gr.Textbox(label="Current Model 2", value=no_model_str,
- interactive=False)
- lora_used2 = gr.Textbox(label="Current LORA 2", value=no_lora_str,
- visible=kwargs['show_lora'], interactive=False)
- server_used2 = gr.Textbox(label="Current Server 2", value=no_server_str,
- interactive=False,
- visible=not is_public)
- prompt_dict2 = gr.Textbox(label="Prompt (or Custom) 2",
- value=pprint.pformat(kwargs['prompt_dict'], indent=4),
- interactive=not is_public, lines=4)
- with gr.Row(visible=not kwargs['model_lock']):
- with gr.Column(scale=50):
- new_model = gr.Textbox(label="New Model name/path", interactive=not is_public)
- with gr.Column(scale=50):
- new_lora = gr.Textbox(label="New LORA name/path", visible=kwargs['show_lora'],
- interactive=not is_public)
- with gr.Column(scale=50):
- new_server = gr.Textbox(label="New Server url:port", interactive=not is_public)
- with gr.Row():
- add_model_lora_server_button = gr.Button("Add new Model, Lora, Server url:port", scale=0,
- size='sm', interactive=not is_public)
- with gr.TabItem("System"):
- with gr.Row():
- with gr.Column(scale=1):
- side_bar_text = gr.Textbox('on', visible=False, interactive=False)
- submit_buttons_text = gr.Textbox('on', visible=False, interactive=False)
-
- side_bar_btn = gr.Button("Toggle SideBar", variant="secondary", size="sm")
- submit_buttons_btn = gr.Button("Toggle Submit Buttons", variant="secondary", size="sm")
- col_tabs_scale = gr.Slider(minimum=1, maximum=20, value=10, step=1, label='Window Size')
- text_outputs_height = gr.Slider(minimum=100, maximum=2000, value=kwargs['height'] or 400,
- step=50, label='Chat Height')
- dark_mode_btn = gr.Button("Dark Mode", variant="secondary", size="sm")
- with gr.Column(scale=4):
- pass
- system_visible0 = not is_public and not admin_pass
- admin_row = gr.Row()
- with admin_row:
- with gr.Column(scale=1):
- admin_pass_textbox = gr.Textbox(label="Admin Password", type='password',
- visible=not system_visible0)
- with gr.Column(scale=4):
- pass
- system_row = gr.Row(visible=system_visible0)
- with system_row:
- with gr.Column():
- with gr.Row():
- system_btn = gr.Button(value='Get System Info', size='sm')
- system_text = gr.Textbox(label='System Info', interactive=False, show_copy_button=True)
- with gr.Row():
- system_input = gr.Textbox(label='System Info Dict Password', interactive=True,
- visible=not is_public)
- system_btn2 = gr.Button(value='Get System Info Dict', visible=not is_public, size='sm')
- system_text2 = gr.Textbox(label='System Info Dict', interactive=False,
- visible=not is_public, show_copy_button=True)
- with gr.Row():
- system_btn3 = gr.Button(value='Get Hash', visible=not is_public, size='sm')
- system_text3 = gr.Textbox(label='Hash', interactive=False,
- visible=not is_public, show_copy_button=True)
-
- with gr.Row():
- zip_btn = gr.Button("Zip", size='sm')
- zip_text = gr.Textbox(label="Zip file name", interactive=False)
- file_output = gr.File(interactive=False, label="Zip file to Download")
- with gr.Row():
- s3up_btn = gr.Button("S3UP", size='sm')
- s3up_text = gr.Textbox(label='S3UP result', interactive=False)
-
- with gr.TabItem("Terms of Service"):
- description = ""
- description += """
DISCLAIMERS:
etc. added in chat, try to remove some of that to help avoid dup entries when hit new conversation - is_same = True - # length of conversation has to be same - if len(x) != len(y): - return False - if len(x) != len(y): - return False - for stepx, stepy in zip(x, y): - if len(stepx) != len(stepy): - # something off with a conversation - return False - for stepxx, stepyy in zip(stepx, stepy): - if len(stepxx) != len(stepyy): - # something off with a conversation - return False - if len(stepxx) != 2: - # something off - return False - if len(stepyy) != 2: - # something off - return False - questionx = stepxx[0].replace('
', '').replace('
', '') if stepxx[0] is not None else None - answerx = stepxx[1].replace('', '').replace('
', '') if stepxx[1] is not None else None - - questiony = stepyy[0].replace('', '').replace('
', '') if stepyy[0] is not None else None - answery = stepyy[1].replace('', '').replace('
', '') if stepyy[1] is not None else None - - if questionx != questiony or answerx != answery: - return False - return is_same - - def save_chat(*args, chat_is_list=False): - args_list = list(args) - if not chat_is_list: - # list of chatbot histories, - # can't pass in list with list of chatbot histories and state due to gradio limits - chat_list = args_list[:-1] - else: - assert len(args_list) == 2 - chat_list = args_list[0] - # if old chat file with single chatbot, get into shape - if isinstance(chat_list, list) and len(chat_list) > 0 and isinstance(chat_list[0], list) and len( - chat_list[0]) == 2 and isinstance(chat_list[0][0], str) and isinstance(chat_list[0][1], str): - chat_list = [chat_list] - # remove None histories - chat_list_not_none = [x for x in chat_list if x and len(x) > 0 and len(x[0]) == 2 and x[0][1] is not None] - chat_list_none = [x for x in chat_list if x not in chat_list_not_none] - if len(chat_list_none) > 0 and len(chat_list_not_none) == 0: - raise ValueError("Invalid chat file") - # dict with keys of short chat names, values of list of list of chatbot histories - chat_state1 = args_list[-1] - short_chats = list(chat_state1.keys()) - if len(chat_list_not_none) > 0: - # make short_chat key from only first history, based upon question that is same anyways - chat_first = chat_list_not_none[0] - short_chat = get_short_chat(chat_first, short_chats) - if short_chat: - old_chat_lists = list(chat_state1.values()) - already_exists = any([is_chat_same(chat_list, x) for x in old_chat_lists]) - if not already_exists: - chat_state1[short_chat] = chat_list.copy() - - # reverse so newest at top - choices = list(chat_state1.keys()).copy() - choices.reverse() - - return chat_state1, gr.update(choices=choices, value=None) - - def switch_chat(chat_key, chat_state1, num_model_lock=0): - chosen_chat = chat_state1[chat_key] - # deal with possible different size of chat list vs. current list - ret_chat = [None] * (2 + num_model_lock) - for chati in range(0, 2 + num_model_lock): - ret_chat[chati % len(ret_chat)] = chosen_chat[chati % len(chosen_chat)] - return tuple(ret_chat) - - def clear_texts(*args): - return tuple([gr.Textbox.update(value='')] * len(args)) - - def clear_scores(): - return gr.Textbox.update(value=res_value), \ - gr.Textbox.update(value='Response Score: NA'), \ - gr.Textbox.update(value='Response Score: NA') - - switch_chat_fun = functools.partial(switch_chat, num_model_lock=len(text_outputs)) - radio_chats.input(switch_chat_fun, - inputs=[radio_chats, chat_state], - outputs=[text_output, text_output2] + text_outputs) \ - .then(clear_scores, outputs=[score_text, score_text2, score_text_nochat]) - - def remove_chat(chat_key, chat_state1): - if isinstance(chat_key, str): - chat_state1.pop(chat_key, None) - return gr.update(choices=list(chat_state1.keys()), value=None), chat_state1 - - remove_chat_event = remove_chat_btn.click(remove_chat, - inputs=[radio_chats, chat_state], outputs=[radio_chats, chat_state], - queue=False, api_name='remove_chat') - - def get_chats1(chat_state1): - base = 'chats' - makedirs(base, exist_ok=True) - filename = os.path.join(base, 'chats_%s.json' % str(uuid.uuid4())) - with open(filename, "wt") as f: - f.write(json.dumps(chat_state1, indent=2)) - return filename - - export_chat_event = export_chats_btn.click(get_chats1, inputs=chat_state, outputs=chats_file, queue=False, - api_name='export_chats' if allow_api else None) - - def add_chats_from_file(file, chat_state1, radio_chats1, chat_exception_text1): - if not file: - return None, chat_state1, gr.update(choices=list(chat_state1.keys()), value=None), chat_exception_text1 - if isinstance(file, str): - files = [file] - else: - files = file - if not files: - return None, chat_state1, gr.update(choices=list(chat_state1.keys()), value=None), chat_exception_text1 - chat_exception_list = [] - for file1 in files: - try: - if hasattr(file1, 'name'): - file1 = file1.name - with open(file1, "rt") as f: - new_chats = json.loads(f.read()) - for chat1_k, chat1_v in new_chats.items(): - # ignore chat1_k, regenerate and de-dup to avoid loss - chat_state1, _ = save_chat(chat1_v, chat_state1, chat_is_list=True) - except BaseException as e: - t, v, tb = sys.exc_info() - ex = ''.join(traceback.format_exception(t, v, tb)) - ex_str = "File %s exception: %s" % (file1, str(e)) - print(ex_str, flush=True) - chat_exception_list.append(ex_str) - chat_exception_text1 = '\n'.join(chat_exception_list) - return None, chat_state1, gr.update(choices=list(chat_state1.keys()), value=None), chat_exception_text1 - - # note for update_user_db_func output is ignored for db - chatup_change_event = chatsup_output.change(add_chats_from_file, - inputs=[chatsup_output, chat_state, radio_chats, - chat_exception_text], - outputs=[chatsup_output, chat_state, radio_chats, - chat_exception_text], - queue=False, - api_name='add_to_chats' if allow_api else None) - - clear_chat_event = clear_chat_btn.click(fn=clear_texts, - inputs=[text_output, text_output2] + text_outputs, - outputs=[text_output, text_output2] + text_outputs, - queue=False, api_name='clear' if allow_api else None) \ - .then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=False) \ - .then(clear_scores, outputs=[score_text, score_text2, score_text_nochat]) - - clear_event = save_chat_btn.click(save_chat, - inputs=[text_output, text_output2] + text_outputs + [chat_state], - outputs=[chat_state, radio_chats], - api_name='save_chat' if allow_api else None) - if kwargs['score_model']: - clear_event2 = clear_event.then(clear_scores, outputs=[score_text, score_text2, score_text_nochat]) - - # NOTE: clear of instruction/iinput for nochat has to come after score, - # because score for nochat consumes actual textbox, while chat consumes chat history filled by user() - no_chat_args = dict(fn=fun, - inputs=[model_state, my_db_state, selection_docs_state] + inputs_list, - outputs=text_output_nochat, - queue=queue, - ) - submit_event_nochat = submit_nochat.click(**no_chat_args, api_name='submit_nochat' if allow_api else None) \ - .then(clear_torch_cache) \ - .then(**score_args_nochat, api_name='instruction_bot_score_nochat' if allow_api else None, queue=queue) \ - .then(clear_instruct, None, instruction_nochat) \ - .then(clear_instruct, None, iinput_nochat) \ - .then(clear_torch_cache) - # copy of above with text box submission - submit_event_nochat2 = instruction_nochat.submit(**no_chat_args) \ - .then(clear_torch_cache) \ - .then(**score_args_nochat, queue=queue) \ - .then(clear_instruct, None, instruction_nochat) \ - .then(clear_instruct, None, iinput_nochat) \ - .then(clear_torch_cache) - - submit_event_nochat_api = submit_nochat_api.click(fun_with_dict_str, - inputs=[model_state, my_db_state, selection_docs_state, - inputs_dict_str], - outputs=text_output_nochat_api, - queue=True, # required for generator - api_name='submit_nochat_api' if allow_api else None) \ - .then(clear_torch_cache) - - def load_model(model_name, lora_weights, server_name, model_state_old, prompt_type_old, load_8bit, - use_gpu_id, gpu_id): - # ensure no API calls reach here - if is_public: - raise RuntimeError("Illegal access for %s" % model_name) - # ensure old model removed from GPU memory - if kwargs['debug']: - print("Pre-switch pre-del GPU memory: %s" % get_torch_allocated(), flush=True) - - model0 = model_state0['model'] - if isinstance(model_state_old['model'], str) and model0 is not None: - # best can do, move model loaded at first to CPU - model0.cpu() - - if model_state_old['model'] is not None and not isinstance(model_state_old['model'], str): - try: - model_state_old['model'].cpu() - except Exception as e: - # sometimes hit NotImplementedError: Cannot copy out of meta tensor; no data! - print("Unable to put model on CPU: %s" % str(e), flush=True) - del model_state_old['model'] - model_state_old['model'] = None - - if model_state_old['tokenizer'] is not None and not isinstance(model_state_old['tokenizer'], str): - del model_state_old['tokenizer'] - model_state_old['tokenizer'] = None - - clear_torch_cache() - if kwargs['debug']: - print("Pre-switch post-del GPU memory: %s" % get_torch_allocated(), flush=True) - - if model_name is None or model_name == no_model_str: - # no-op if no model, just free memory - # no detranscribe needed for model, never go into evaluate - lora_weights = no_lora_str - server_name = no_server_str - return [None, None, None, model_name, server_name], \ - model_name, lora_weights, server_name, prompt_type_old, \ - gr.Slider.update(maximum=256), \ - gr.Slider.update(maximum=256) - - # don't deepcopy, can contain model itself - all_kwargs1 = all_kwargs.copy() - all_kwargs1['base_model'] = model_name.strip() - all_kwargs1['load_8bit'] = load_8bit - all_kwargs1['use_gpu_id'] = use_gpu_id - all_kwargs1['gpu_id'] = int(gpu_id) # detranscribe - model_lower = model_name.strip().lower() - if model_lower in inv_prompt_type_to_model_lower: - prompt_type1 = inv_prompt_type_to_model_lower[model_lower] - else: - prompt_type1 = prompt_type_old - - # detranscribe - if lora_weights == no_lora_str: - lora_weights = '' - all_kwargs1['lora_weights'] = lora_weights.strip() - if server_name == no_server_str: - server_name = '' - all_kwargs1['inference_server'] = server_name.strip() - - model1, tokenizer1, device1 = get_model(reward_type=False, - **get_kwargs(get_model, exclude_names=['reward_type'], - **all_kwargs1)) - clear_torch_cache() - - tokenizer_base_model = model_name - prompt_dict1, error0 = get_prompt(prompt_type1, '', - chat=False, context='', reduced=False, making_context=False, - return_dict=True) - model_state_new = dict(model=model1, tokenizer=tokenizer1, device=device1, - base_model=model_name, tokenizer_base_model=tokenizer_base_model, - lora_weights=lora_weights, inference_server=server_name, - prompt_type=prompt_type1, prompt_dict=prompt_dict1, - ) - - max_max_new_tokens1 = get_max_max_new_tokens(model_state_new, **kwargs) - - if kwargs['debug']: - print("Post-switch GPU memory: %s" % get_torch_allocated(), flush=True) - return model_state_new, model_name, lora_weights, server_name, prompt_type1, \ - gr.Slider.update(maximum=max_max_new_tokens1), \ - gr.Slider.update(maximum=max_max_new_tokens1) - - def get_prompt_str(prompt_type1, prompt_dict1, which=0): - if prompt_type1 in ['', None]: - print("Got prompt_type %s: %s" % (which, prompt_type1), flush=True) - return str({}) - prompt_dict1, prompt_dict_error = get_prompt(prompt_type1, prompt_dict1, chat=False, context='', - reduced=False, making_context=False, return_dict=True) - if prompt_dict_error: - return str(prompt_dict_error) - else: - # return so user can manipulate if want and use as custom - return str(prompt_dict1) - - get_prompt_str_func1 = functools.partial(get_prompt_str, which=1) - get_prompt_str_func2 = functools.partial(get_prompt_str, which=2) - prompt_type.change(fn=get_prompt_str_func1, inputs=[prompt_type, prompt_dict], outputs=prompt_dict, queue=False) - prompt_type2.change(fn=get_prompt_str_func2, inputs=[prompt_type2, prompt_dict2], outputs=prompt_dict2, - queue=False) - - def dropdown_prompt_type_list(x): - return gr.Dropdown.update(value=x) - - def chatbot_list(x, model_used_in): - return gr.Textbox.update(label=f'h2oGPT [Model: {model_used_in}]') - - load_model_args = dict(fn=load_model, - inputs=[model_choice, lora_choice, server_choice, model_state, prompt_type, - model_load8bit_checkbox, model_use_gpu_id_checkbox, model_gpu], - outputs=[model_state, model_used, lora_used, server_used, - # if prompt_type changes, prompt_dict will change via change rule - prompt_type, max_new_tokens, min_new_tokens, - ]) - prompt_update_args = dict(fn=dropdown_prompt_type_list, inputs=prompt_type, outputs=prompt_type) - chatbot_update_args = dict(fn=chatbot_list, inputs=[text_output, model_used], outputs=text_output) - nochat_update_args = dict(fn=chatbot_list, inputs=[text_output_nochat, model_used], outputs=text_output_nochat) - load_model_event = load_model_button.click(**load_model_args, - api_name='load_model' if allow_api and is_public else None) \ - .then(**prompt_update_args) \ - .then(**chatbot_update_args) \ - .then(**nochat_update_args) \ - .then(clear_torch_cache) - - load_model_args2 = dict(fn=load_model, - inputs=[model_choice2, lora_choice2, server_choice2, model_state2, prompt_type2, - model_load8bit_checkbox2, model_use_gpu_id_checkbox2, model_gpu2], - outputs=[model_state2, model_used2, lora_used2, server_used2, - # if prompt_type2 changes, prompt_dict2 will change via change rule - prompt_type2, max_new_tokens2, min_new_tokens2 - ]) - prompt_update_args2 = dict(fn=dropdown_prompt_type_list, inputs=prompt_type2, outputs=prompt_type2) - chatbot_update_args2 = dict(fn=chatbot_list, inputs=[text_output2, model_used2], outputs=text_output2) - load_model_event2 = load_model_button2.click(**load_model_args2, - api_name='load_model2' if allow_api and is_public else None) \ - .then(**prompt_update_args2) \ - .then(**chatbot_update_args2) \ - .then(clear_torch_cache) - - def dropdown_model_lora_server_list(model_list0, model_x, - lora_list0, lora_x, - server_list0, server_x, - model_used1, lora_used1, server_used1, - model_used2, lora_used2, server_used2, - ): - model_new_state = [model_list0[0] + [model_x]] - model_new_options = [*model_new_state[0]] - x1 = model_x if model_used1 == no_model_str else model_used1 - x2 = model_x if model_used2 == no_model_str else model_used2 - ret1 = [gr.Dropdown.update(value=x1, choices=model_new_options), - gr.Dropdown.update(value=x2, choices=model_new_options), - '', model_new_state] - - lora_new_state = [lora_list0[0] + [lora_x]] - lora_new_options = [*lora_new_state[0]] - # don't switch drop-down to added lora if already have model loaded - x1 = lora_x if model_used1 == no_model_str else lora_used1 - x2 = lora_x if model_used2 == no_model_str else lora_used2 - ret2 = [gr.Dropdown.update(value=x1, choices=lora_new_options), - gr.Dropdown.update(value=x2, choices=lora_new_options), - '', lora_new_state] - - server_new_state = [server_list0[0] + [server_x]] - server_new_options = [*server_new_state[0]] - # don't switch drop-down to added server if already have model loaded - x1 = server_x if model_used1 == no_model_str else server_used1 - x2 = server_x if model_used2 == no_model_str else server_used2 - ret3 = [gr.Dropdown.update(value=x1, choices=server_new_options), - gr.Dropdown.update(value=x2, choices=server_new_options), - '', server_new_state] - - return tuple(ret1 + ret2 + ret3) - - add_model_lora_server_event = \ - add_model_lora_server_button.click(fn=dropdown_model_lora_server_list, - inputs=[model_options_state, new_model] + - [lora_options_state, new_lora] + - [server_options_state, new_server] + - [model_used, lora_used, server_used] + - [model_used2, lora_used2, server_used2], - outputs=[model_choice, model_choice2, new_model, model_options_state] + - [lora_choice, lora_choice2, new_lora, lora_options_state] + - [server_choice, server_choice2, new_server, - server_options_state], - queue=False) - - go_event = go_btn.click(lambda: gr.update(visible=False), None, go_btn, api_name="go" if allow_api else None, - queue=False) \ - .then(lambda: gr.update(visible=True), None, normal_block, queue=False) \ - .then(**load_model_args, queue=False).then(**prompt_update_args, queue=False) - - def compare_textbox_fun(x): - return gr.Textbox.update(visible=x) - - def compare_column_fun(x): - return gr.Column.update(visible=x) - - def compare_prompt_fun(x): - return gr.Dropdown.update(visible=x) - - def slider_fun(x): - return gr.Slider.update(visible=x) - - compare_checkbox.select(compare_textbox_fun, compare_checkbox, text_output2, - api_name="compare_checkbox" if allow_api else None) \ - .then(compare_column_fun, compare_checkbox, col_model2) \ - .then(compare_prompt_fun, compare_checkbox, prompt_type2) \ - .then(compare_textbox_fun, compare_checkbox, score_text2) \ - .then(slider_fun, compare_checkbox, max_new_tokens2) \ - .then(slider_fun, compare_checkbox, min_new_tokens2) - # FIXME: add score_res2 in condition, but do better - - # callback for logging flagged input/output - callback.setup(inputs_list + [text_output, text_output2] + text_outputs, "flagged_data_points") - flag_btn.click(lambda *args: callback.flag(args), inputs_list + [text_output, text_output2] + text_outputs, - None, - preprocess=False, - api_name='flag' if allow_api else None, queue=False) - flag_btn_nochat.click(lambda *args: callback.flag(args), inputs_list + [text_output_nochat], None, - preprocess=False, - api_name='flag_nochat' if allow_api else None, queue=False) - - def get_system_info(): - if is_public: - time.sleep(10) # delay to avoid spam since queue=False - return gr.Textbox.update(value=system_info_print()) - - system_event = system_btn.click(get_system_info, outputs=system_text, - api_name='system_info' if allow_api else None, queue=False) - - def get_system_info_dict(system_input1, **kwargs1): - if system_input1 != os.getenv("ADMIN_PASS", ""): - return json.dumps({}) - exclude_list = ['admin_pass', 'examples'] - sys_dict = {k: v for k, v in kwargs1.items() if - isinstance(v, (str, int, bool, float)) and k not in exclude_list} - try: - sys_dict.update(system_info()) - except Exception as e: - # protection - print("Exception: %s" % str(e), flush=True) - return json.dumps(sys_dict) - - system_kwargs = all_kwargs.copy() - system_kwargs.update(dict(command=str(' '.join(sys.argv)))) - get_system_info_dict_func = functools.partial(get_system_info_dict, **all_kwargs) - - system_dict_event = system_btn2.click(get_system_info_dict_func, - inputs=system_input, - outputs=system_text2, - api_name='system_info_dict' if allow_api else None, - queue=False, # queue to avoid spam - ) - - def get_hash(): - return kwargs['git_hash'] - - system_event = system_btn3.click(get_hash, - outputs=system_text3, - api_name='system_hash' if allow_api else None, - queue=False, - ) - - def count_chat_tokens(model_state1, chat1, prompt_type1, prompt_dict1, - memory_restriction_level1=0, - keep_sources_in_context1=False, - ): - if model_state1 and not isinstance(model_state1['tokenizer'], str): - tokenizer = model_state1['tokenizer'] - elif model_state0 and not isinstance(model_state0['tokenizer'], str): - tokenizer = model_state0['tokenizer'] - else: - tokenizer = None - if tokenizer is not None: - langchain_mode1 = 'LLM' - add_chat_history_to_context1 = True - # fake user message to mimic bot() - chat1 = copy.deepcopy(chat1) - chat1 = chat1 + [['user_message1', None]] - model_max_length1 = tokenizer.model_max_length - context1 = history_to_context(chat1, langchain_mode1, - add_chat_history_to_context1, - prompt_type1, prompt_dict1, chat1, - model_max_length1, - memory_restriction_level1, keep_sources_in_context1) - return str(tokenizer(context1, return_tensors="pt")['input_ids'].shape[1]) - else: - return "N/A" - - count_chat_tokens_func = functools.partial(count_chat_tokens, - memory_restriction_level1=memory_restriction_level, - keep_sources_in_context1=kwargs['keep_sources_in_context']) - count_tokens_event = count_chat_tokens_btn.click(fn=count_chat_tokens, - inputs=[model_state, text_output, prompt_type, prompt_dict], - outputs=chat_token_count, - api_name='count_tokens' if allow_api else None) - - # don't pass text_output, don't want to clear output, just stop it - # cancel only stops outer generation, not inner generation or non-generation - stop_btn.click(lambda: None, None, None, - cancels=submits1 + submits2 + submits3 + submits4 + - [submit_event_nochat, submit_event_nochat2] + - [eventdb1, eventdb2, eventdb3] + - [eventdb7, eventdb8, eventdb9, eventdb12] + - db_events + - [clear_event] + - [submit_event_nochat_api, submit_event_nochat] + - [load_model_event, load_model_event2] + - [count_tokens_event] - , - queue=False, api_name='stop' if allow_api else None).then(clear_torch_cache, queue=False) - - demo.load(None, None, None, _js=get_dark_js() if kwargs['dark'] else None) - - demo.queue(concurrency_count=kwargs['concurrency_count'], api_open=kwargs['api_open']) - favicon_path = "h2o-logo.svg" - if not os.path.isfile(favicon_path): - print("favicon_path=%s not found" % favicon_path, flush=True) - favicon_path = None - - scheduler = BackgroundScheduler() - scheduler.add_job(func=clear_torch_cache, trigger="interval", seconds=20) - if is_public and \ - kwargs['base_model'] not in non_hf_types: - # FIXME: disable for gptj, langchain or gpt4all modify print itself - # FIXME: and any multi-threaded/async print will enter model output! - scheduler.add_job(func=ping, trigger="interval", seconds=60) - if is_public or os.getenv('PING_GPU'): - scheduler.add_job(func=ping_gpu, trigger="interval", seconds=60 * 10) - scheduler.start() - - # import control - if kwargs['langchain_mode'] == 'Disabled' and \ - os.environ.get("TEST_LANGCHAIN_IMPORT") and \ - kwargs['base_model'] not in non_hf_types: - assert 'gpt_langchain' not in sys.modules, "Dev bug, import of langchain when should not have" - assert 'langchain' not in sys.modules, "Dev bug, import of langchain when should not have" - - demo.launch(share=kwargs['share'], server_name="0.0.0.0", show_error=True, - favicon_path=favicon_path, prevent_thread_lock=True, - auth=kwargs['auth']) - if kwargs['verbose']: - print("Started GUI", flush=True) - if kwargs['block_gradio_exit']: - demo.block_thread() - - -def get_inputs_list(inputs_dict, model_lower, model_id=1): - """ - map gradio objects in locals() to inputs for evaluate(). - :param inputs_dict: - :param model_lower: - :param model_id: Which model (1 or 2) of 2 - :return: - """ - inputs_list_names = list(inspect.signature(evaluate).parameters) - inputs_list = [] - inputs_dict_out = {} - for k in inputs_list_names: - if k == 'kwargs': - continue - if k in input_args_list + inputs_kwargs_list: - # these are added at use time for args or partial for kwargs, not taken as input - continue - if 'mbart-' not in model_lower and k in ['src_lang', 'tgt_lang']: - continue - if model_id == 2: - if k == 'prompt_type': - k = 'prompt_type2' - if k == 'prompt_used': - k = 'prompt_used2' - if k == 'max_new_tokens': - k = 'max_new_tokens2' - if k == 'min_new_tokens': - k = 'min_new_tokens2' - inputs_list.append(inputs_dict[k]) - inputs_dict_out[k] = inputs_dict[k] - return inputs_list, inputs_dict_out - - -def get_sources(db1s, langchain_mode, dbs=None, docs_state0=None): - for k in db1s: - set_userid(db1s[k]) - - if langchain_mode in ['LLM']: - source_files_added = "NA" - source_list = [] - elif langchain_mode in ['wiki_full']: - source_files_added = "Not showing wiki_full, takes about 20 seconds and makes 4MB file." \ - " Ask jon.mckinney@h2o.ai for file if required." - source_list = [] - elif langchain_mode in db1s and len(db1s[langchain_mode]) == 2 and db1s[langchain_mode][0] is not None: - db1 = db1s[langchain_mode] - from gpt_langchain import get_metadatas - metadatas = get_metadatas(db1[0]) - source_list = sorted(set([x['source'] for x in metadatas])) - source_files_added = '\n'.join(source_list) - elif langchain_mode in dbs and dbs[langchain_mode] is not None: - from gpt_langchain import get_metadatas - db1 = dbs[langchain_mode] - metadatas = get_metadatas(db1) - source_list = sorted(set([x['source'] for x in metadatas])) - source_files_added = '\n'.join(source_list) - else: - source_list = [] - source_files_added = "None" - sources_dir = "sources_dir" - makedirs(sources_dir) - sources_file = os.path.join(sources_dir, 'sources_%s_%s' % (langchain_mode, str(uuid.uuid4()))) - with open(sources_file, "wt") as f: - f.write(source_files_added) - source_list = docs_state0 + source_list - return sources_file, source_list - - -def set_userid(db1): - # can only call this after function called so for specific userr, not in gr.State() that occurs during app init - assert db1 is not None and len(db1) == 2 - if db1[1] is None: - # uuid in db is used as user ID - db1[1] = str(uuid.uuid4()) - - -def update_user_db(file, db1s, selection_docs_state1, chunk, chunk_size, langchain_mode, dbs=None, **kwargs): - kwargs.update(selection_docs_state1) - if file is None: - raise RuntimeError("Don't use change, use input") - - try: - return _update_user_db(file, db1s=db1s, chunk=chunk, chunk_size=chunk_size, - langchain_mode=langchain_mode, dbs=dbs, - **kwargs) - except BaseException as e: - print(traceback.format_exc(), flush=True) - # gradio has issues if except, so fail semi-gracefully, else would hang forever in processing textbox - ex_str = "Exception: %s" % str(e) - source_files_added = """\ - - -
- Sources:
-
- {0}
-
- {0}
-
- Exceptions:
-