#!/usr/bin/env python3
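'''Gradio chat UI entrypoint.

Serves a character chat interface backed either by a local HuggingFace
Transformers model or by a remote KoboldAI instance.

Example invocations (the model name and URL below are illustrative; the
flags are defined in _parse_args_from_argv):

    python3 app.py --model-name PygmalionAI/pygmalion-6b --port 3000
    python3 app.py --koboldai-url http://localhost:5000 --share
'''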
import argparse
import logging
import typing as t

from gradio_ui import build_gradio_ui_for
from koboldai_client import run_raw_inference_on_kai, wait_for_kai_server
from parsing import parse_messages_from_str
from prompting import build_prompt_for
from utils import clear_stdout

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# For UI debugging purposes.
DONT_USE_MODEL = False


def main(server_port: int,
         share_gradio_link: bool = False,
         model_name: t.Optional[str] = None,
         koboldai_url: t.Optional[str] = None) -> None:
    '''Script entrypoint.'''
    if model_name and not DONT_USE_MODEL:
        from model import build_model_and_tokenizer_for, run_raw_inference
        model, tokenizer = build_model_and_tokenizer_for(model_name)
    else:
        model, tokenizer = None, None

    def inference_fn(history: t.List[str], user_input: str,
                     generation_settings: t.Dict[str, t.Any],
                     *char_settings: t.Any) -> str:
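        '''Gradio callback: produces the bot's reply for a single chat turn.

        `history` is the conversation so far, `user_input` is the latest user
        message, and `char_settings` are the character fields unpacked
        positionally below.
        '''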
        if DONT_USE_MODEL:
            return "Mock response for UI tests."
        # Brittle. Comes from the order defined in gradio_ui.py.
        [
            char_name,
            _user_name,
            char_persona,
            char_greeting,
            world_scenario,
            example_dialogue,
        ] = char_settings

        # If we're just starting the conversation and the character has a
        # greeting configured, return that instead. This works around the
        # fact that Gradio assumes a chatbot can't start the conversation,
        # so the greeting can't simply be shown up front; it has to arrive
        # as a response to a user message.
        if len(history) == 0 and char_greeting is not None:
            return f"{char_name}: {char_greeting}"

        prompt = build_prompt_for(history=history,
                                  user_message=user_input,
                                  char_name=char_name,
                                  char_persona=char_persona,
                                  example_dialogue=example_dialogue,
                                  world_scenario=world_scenario)

        if model and tokenizer:
            model_output = run_raw_inference(model, tokenizer, prompt,
                                             user_input,
                                             **generation_settings)
        elif koboldai_url:
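            # Prepend the character's name so the remote completion is
            # attributed to the bot and can be parsed out below.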
            model_output = f"{char_name}:"
            model_output += run_raw_inference_on_kai(koboldai_url, prompt,
                                                     **generation_settings)
        else:
            raise Exception(
                "Not using local inference, but no KoboldAI instance URL was"
                " given. Nowhere to run inference.")

        generated_messages = parse_messages_from_str(model_output,
                                                     ["You", char_name])
        logger.debug("Parsed model response is: `%s`", generated_messages)
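        # The model may keep generating past its own turn; keep only the
        # first parsed message as the reply.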
        bot_message = generated_messages[0]

        return bot_message

    ui = build_gradio_ui_for(inference_fn, for_kobold=koboldai_url is not None)
    ui.launch(server_port=server_port, share=share_gradio_link)


def _parse_args_from_argv() -> argparse.Namespace:
    '''Parses arguments coming in from the command line.'''
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--model-name",
        help="HuggingFace Transformers model name, if not using a KoboldAI"
        " instance as an inference server.",
    )
    parser.add_argument(
        "-p",
        "--port",
        type=int,
        default=3000,
        help="Port to listen on.",
    )
    parser.add_argument(
        "-k",
        "--koboldai-url",
        help="URL to a KoboldAI instance to use as an inference server.",
    )
    parser.add_argument(
        "-s",
        "--share",
        action="store_true",
        help="Enable to generate a public link for the Gradio UI.",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = _parse_args_from_argv()

    if args.koboldai_url:
        # No idea what a safe wait time is, but we'd rather wait too long
        # than cut the user off _right_ when the setup is about to finish,
        # so let's pick something absurd here.
        wait_for_kai_server(args.koboldai_url, max_wait_time_seconds=60 * 30)

        # Clear out any Kobold logs so the user can clearly see the Gradio
        # link that's about to show up afterwards.
        clear_stdout()

    main(model_name=args.model_name,
         server_port=args.port,
         koboldai_url=args.koboldai_url,
         share_gradio_link=args.share)