#!/usr/bin/env python3
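# Launches a Gradio chat UI backed by either a local HuggingFace Transformers
# model (-m/--model-name) or a remote KoboldAI instance (-k/--koboldai-url).
#
# Example invocations (script name, model name and URL are illustrative):
#   python3 main.py --model-name EleutherAI/gpt-neo-1.3B --port 3000
#   python3 main.py --koboldai-url http://localhost:5000 --share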
import argparse
import logging
import typing as t

from gradio_ui import build_gradio_ui_for
from koboldai_client import run_raw_inference_on_kai, wait_for_kai_server
from parsing import parse_messages_from_str
from prompting import build_prompt_for
from utils import clear_stdout

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# For UI debugging purposes: when True, model loading is skipped and a canned
# response is returned instead of running inference.
DONT_USE_MODEL = False


def main(server_port: int,
         share_gradio_link: bool = False,
         model_name: t.Optional[str] = None,
         koboldai_url: t.Optional[str] = None) -> None:
    '''Script entrypoint.'''
    if model_name and not DONT_USE_MODEL:
        from model import build_model_and_tokenizer_for, run_raw_inference
        model, tokenizer = build_model_and_tokenizer_for(model_name)
    else:
        model, tokenizer = None, None

    def inference_fn(history: t.List[str], user_input: str,
                     generation_settings: t.Dict[str, t.Any],
                     *char_settings: t.Any) -> str:
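        '''Handles a single chat turn coming in from the Gradio UI.

        `history` is the conversation so far, `user_input` is the newest user
        message, and `char_settings` carries the character configuration
        fields in the order they're defined in gradio_ui.py.
        '''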
        if DONT_USE_MODEL:
            return "Mock response for UI tests."

        # Brittle: this unpacking relies on the component order defined in
        # gradio_ui.py.
        [
            char_name,
            _user_name,
            char_persona,
            char_greeting,
            world_scenario,
            example_dialogue,
        ] = char_settings

        # If we're just starting the conversation and the character has a
        # greeting configured, return that instead. This works around the fact
        # that Gradio assumes a chatbot can never open a conversation, so the
        # greeting can't simply be there from the start: it has to come as a
        # response to a user message.
        if len(history) == 0 and char_greeting is not None:
            return f"{char_name}: {char_greeting}"

        prompt = build_prompt_for(history=history,
                                  user_message=user_input,
                                  char_name=char_name,
                                  char_persona=char_persona,
                                  example_dialogue=example_dialogue,
                                  world_scenario=world_scenario)

        if model and tokenizer:
            model_output = run_raw_inference(model, tokenizer, prompt,
                                             user_input, **generation_settings)
        elif koboldai_url:
            model_output = f"{char_name}:"
            model_output += run_raw_inference_on_kai(koboldai_url, prompt,
                                                     **generation_settings)
        else:
            raise RuntimeError(
                "Not using local inference, but no KoboldAI instance URL was"
                " given, so there is nowhere to run inference.")

        generated_messages = parse_messages_from_str(model_output,
                                                     ["You", char_name])
        logger.debug("Parsed model response is: `%s`", generated_messages)
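        # Keep only the first parsed message; anything after it is usually the
        # model speaking past its turn (e.g. impersonating the user).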
        bot_message = generated_messages[0]

        return bot_message

    ui = build_gradio_ui_for(inference_fn, for_kobold=koboldai_url is not None)
    ui.launch(server_port=server_port, share=share_gradio_link)


def _parse_args_from_argv() -> argparse.Namespace:
    '''Parses arguments coming in from the command line.'''
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--model-name",
        help="HuggingFace Transformers model name, if not using a KoboldAI"
        " instance as an inference server.",
    )
    parser.add_argument(
        "-p",
        "--port",
        type=int,
        default=3000,
        help="Port to listen on.",
    )
    parser.add_argument(
        "-k",
        "--koboldai-url",
        help="URL to a KoboldAI instance to use as an inference server.",
    )
    parser.add_argument(
        "-s",
        "--share",
        action="store_true",
        help="Enable to generate a public link for the Gradio UI.",
    )

    return parser.parse_args()


if __name__ == "__main__":
    args = _parse_args_from_argv()

    if args.koboldai_url:
        # I have no idea how long a safe wait time is, but we'd rather wait
        # too long than cut the user off _right_ when the setup is about to
        # finish, so let's pick something absurdly generous here.
        wait_for_kai_server(args.koboldai_url, max_wait_time_seconds=60 * 30)

        # Clear out any Kobold logs so the user can clearly see the Gradio link
        # that's about to show up afterwards.
        clear_stdout()

    main(model_name=args.model_name,
         server_port=args.port,
         koboldai_url=args.koboldai_url,
         share_gradio_link=args.share)