#!/usr/bin/env python3
import argparse
import logging
import typing as t
from gradio_ui import build_gradio_ui_for
from koboldai_client import run_raw_inference_on_kai, wait_for_kai_server
from parsing import parse_messages_from_str
from prompting import build_prompt_for
from utils import clear_stdout
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
# For UI debugging purposes: when True, skips model loading and has
# `inference_fn` return a canned response instead.
DONT_USE_MODEL = False
def main(server_port: int,
         share_gradio_link: bool = False,
         model_name: t.Optional[str] = None,
         koboldai_url: t.Optional[str] = None) -> None:
    '''Script entrypoint.'''
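    # `model` is imported lazily, presumably so that a KoboldAI-only setup
    # doesn't need the local inference dependencies installed.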
    if model_name and not DONT_USE_MODEL:
        from model import build_model_and_tokenizer_for, run_raw_inference
        model, tokenizer = build_model_and_tokenizer_for(model_name)
    else:
        model, tokenizer = None, None
    def inference_fn(history: t.List[str], user_input: str,
                     generation_settings: t.Dict[str, t.Any],
                     *char_settings: t.Any) -> str:
        if DONT_USE_MODEL:
            return "Mock response for UI tests."

        # Brittle. Comes from the order defined in gradio_ui.py.
        [
            char_name,
            _user_name,
            char_persona,
            char_greeting,
            world_scenario,
            example_dialogue,
        ] = char_settings
        # If we're just starting the conversation and the character has a
        # greeting configured, return that instead. This is a workaround for
        # the fact that Gradio assumes a chatbot cannot start a conversation,
        # so the greeting can't just be shown automatically; it needs to come
        # as a response to a user message.
        if len(history) == 0 and char_greeting is not None:
            return f"{char_name}: {char_greeting}"
        prompt = build_prompt_for(history=history,
                                  user_message=user_input,
                                  char_name=char_name,
                                  char_persona=char_persona,
                                  example_dialogue=example_dialogue,
                                  world_scenario=world_scenario)
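
        # Run inference on whichever backend is configured: the locally
        # loaded Transformers model, or a remote KoboldAI instance.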
        if model and tokenizer:
            model_output = run_raw_inference(model, tokenizer, prompt,
                                             user_input, **generation_settings)
        elif koboldai_url:
            model_output = f"{char_name}:"
            model_output += run_raw_inference_on_kai(koboldai_url, prompt,
                                                     **generation_settings)
        else:
            raise Exception(
                "Not using local inference, but no KoboldAI instance URL was"
                " given, so there's nowhere to run inference on.")
        generated_messages = parse_messages_from_str(model_output,
                                                     ["You", char_name])
        logger.debug("Parsed model response is: `%s`", generated_messages)
        bot_message = generated_messages[0]
        return bot_message
    ui = build_gradio_ui_for(inference_fn, for_kobold=koboldai_url is not None)
    ui.launch(server_port=server_port, share=share_gradio_link)

def _parse_args_from_argv() -> argparse.Namespace:
    '''Parses arguments coming in from the command line.'''
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--model-name",
        help="HuggingFace Transformers model name, if not using a KoboldAI"
        " instance as an inference server.",
    )
    parser.add_argument(
        "-p",
        "--port",
        type=int,
        default=3000,
        help="Port to listen on.",
    )
    parser.add_argument(
        "-k",
        "--koboldai-url",
        help="URL to a KoboldAI instance to use as an inference server.",
    )
    parser.add_argument(
        "-s",
        "--share",
        action="store_true",
        help="Enable to generate a public link for the Gradio UI.",
    )
    return parser.parse_args()

if __name__ == "__main__":
    args = _parse_args_from_argv()

    if args.koboldai_url:
        # There's no telling how long a safe wait time is, but it's better to
        # wait too long than to cut the user off _right_ when the setup is
        # about to finish, so pick something absurdly generous here.
        wait_for_kai_server(args.koboldai_url, max_wait_time_seconds=60 * 30)

        # Clear out any Kobold logs so the user can clearly see the Gradio
        # link that's about to show up afterwards.
        clear_stdout()

    main(model_name=args.model_name,
         server_port=args.port,
         koboldai_url=args.koboldai_url,
         share_gradio_link=args.share)
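
# Example invocations (a sketch: assumes this file is saved as app.py; the
# model name and URL below are placeholders, not values documented upstream):
#
#   # Local inference with a HuggingFace Transformers model:
#   python3 app.py --model-name <hf-model-name> --port 3000
#
#   # Inference against a running KoboldAI instance, with a public link:
#   python3 app.py --koboldai-url http://localhost:5000 --share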