#!/usr/bin/env python3
import argparse
import logging
import typing as t

from gradio_ui import build_gradio_ui_for
from koboldai_client import run_raw_inference_on_kai, wait_for_kai_server
from parsing import parse_messages_from_str
from prompting import build_prompt_for
from utils import clear_stdout

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# For UI debugging purposes.
DONT_USE_MODEL = False


def main(server_port: int,
         share_gradio_link: bool = False,
         model_name: t.Optional[str] = None,
         koboldai_url: t.Optional[str] = None) -> None:
    '''Script entrypoint.'''
    if model_name and not DONT_USE_MODEL:
        # Imported lazily so the heavyweight model dependencies are only
        # pulled in when doing local inference.
        from model import build_model_and_tokenizer_for, run_raw_inference
        model, tokenizer = build_model_and_tokenizer_for(model_name)
    else:
        model, tokenizer = None, None

    def inference_fn(history: t.List[str], user_input: str,
                     generation_settings: t.Dict[str, t.Any],
                     *char_settings: t.Any) -> str:
        if DONT_USE_MODEL:
            return "Mock response for UI tests."

        # Brittle: relies on the argument order defined in gradio_ui.py.
        [
            char_name,
            _user_name,
            char_persona,
            char_greeting,
            world_scenario,
            example_dialogue,
        ] = char_settings

        # If we're just starting the conversation and the character has a
        # greeting configured, return that instead. This is a workaround for
        # the fact that Gradio assumes a chatbot cannot start a conversation,
        # so we can't just show the greeting automatically; it needs to be a
        # response to a user message.
        if len(history) == 0 and char_greeting is not None:
            return f"{char_name}: {char_greeting}"

        prompt = build_prompt_for(history=history,
                                  user_message=user_input,
                                  char_name=char_name,
                                  char_persona=char_persona,
                                  example_dialogue=example_dialogue,
                                  world_scenario=world_scenario)

        if model and tokenizer:
            model_output = run_raw_inference(model, tokenizer, prompt,
                                             user_input,
                                             **generation_settings)
        elif koboldai_url:
            model_output = f"{char_name}:"
            model_output += run_raw_inference_on_kai(koboldai_url, prompt,
                                                     **generation_settings)
        else:
            raise Exception(
                "Not using local inference, but no Kobold instance URL was"
                " given. Nowhere to perform inference on.")

        generated_messages = parse_messages_from_str(model_output,
                                                     ["You", char_name])
        logger.debug("Parsed model response is: `%s`", generated_messages)
        bot_message = generated_messages[0]

        return bot_message

    ui = build_gradio_ui_for(inference_fn,
                             for_kobold=koboldai_url is not None)
    ui.launch(server_port=server_port, share=share_gradio_link)


def _parse_args_from_argv() -> argparse.Namespace:
    '''Parses arguments coming in from the command line.'''
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--model-name",
        help="HuggingFace Transformers model name, if not using a KoboldAI"
        " instance as an inference server.",
    )
    parser.add_argument(
        "-p",
        "--port",
        type=int,
        default=3000,
        help="Port to listen on.",
    )
    parser.add_argument(
        "-k",
        "--koboldai-url",
        help="URL to a KoboldAI instance to use as an inference server.",
    )
    parser.add_argument(
        "-s",
        "--share",
        action="store_true",
        help="Enable to generate a public link for the Gradio UI.",
    )

    return parser.parse_args()


if __name__ == "__main__":
    args = _parse_args_from_argv()

    if args.koboldai_url:
        # I have no idea how long a safe wait time is, but we'd rather wait
        # too long than cut the user off _right_ when the setup is about to
        # finish, so let's pick something absurd here.
        wait_for_kai_server(args.koboldai_url, max_wait_time_seconds=60 * 30)

        # Clear out any Kobold logs so the user can clearly see the Gradio
        # link that's about to show up afterwards.
        clear_stdout()

    main(model_name=args.model_name,
         server_port=args.port,
         koboldai_url=args.koboldai_url,
         share_gradio_link=args.share)
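
# Example invocations, sketched from the argparse flags defined above. The
# filename `app.py` and the model name are illustrative assumptions, not
# confirmed by the source:
#
#   # Local inference with a HuggingFace Transformers model:
#   python3 app.py --model-name some-org/some-model --port 3000
#
#   # Remote inference against a running KoboldAI instance, with a public
#   # Gradio link:
#   python3 app.py --koboldai-url http://127.0.0.1:5000 --share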