import json
import logging

import gradio as gr


def get_generation_defaults(for_kobold):
    '''
    Returns the default generation settings for the given backend: Kobold
    takes a `max_context_length`, while the HF backend takes a `penalty_alpha`
    for contrastive search.
    '''
    defaults = {
        "do_sample": True,
        "max_new_tokens": 196,
        "temperature": 0.5,
        "top_p": 0.9,
        "top_k": 0,
        "typical_p": 1.0,
        "repetition_penalty": 1.05,
    }

    if for_kobold:
        defaults.update({"max_context_length": 768})
    else:
        defaults.update({"penalty_alpha": 0.6})

    return defaults
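
# Illustrative sketch (not called anywhere in this module): these defaults seed
# the `generation_settings` gr.State below, and each slider change merges a
# single key back in via _update_generation_settings:
#
#   settings = get_generation_defaults(for_kobold=False)
#   settings = {**settings, "temperature": 0.7}  # what a slider change effectively does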


logger = logging.getLogger(__name__)


def build_gradio_ui_for(inference_fn, for_kobold):
    '''
    Builds a Gradio UI to interact with the model. Big thanks to TearGosling for
    the initial version that inspired this.
    '''
    with gr.Blocks(title="Pygmalion", analytics_enabled=False) as interface:
        history_for_gradio = gr.State([])
        history_for_model = gr.State([])
        generation_settings = gr.State(
            get_generation_defaults(for_kobold=for_kobold))

        def _update_generation_settings(
            original_settings,
            param_name,
            new_value,
        ):
            '''
            Merges `{param_name: new_value}` into `original_settings` and
            returns a new dictionary.
            '''
            updated_settings = {**original_settings, param_name: new_value}
            logger.debug("Generation settings updated to: `%s`",
                         updated_settings)
            return updated_settings

        def _run_inference(
            model_history,
            gradio_history,
            user_input,
            generation_settings,
            *char_setting_states,
        ):
            '''
            Runs inference on the model, and formats the returned response for
            the Gradio state and chatbot component.
            '''
            char_name = char_setting_states[0]
            user_name = char_setting_states[1]

            # If user input is blank, format it as if user was silent
            if user_input is None or user_input.strip() == "":
                user_input = "..."

            inference_result = inference_fn(model_history, user_input,
                                            generation_settings,
                                            *char_setting_states)

            inference_result_for_gradio = inference_result \
                .replace(f"{char_name}:", f"**{char_name}:**") \
                .replace("<USER>", user_name) \
                .replace("\n", "<br>")  # the Gradio chatbot renders <br> tags as line breaks

            model_history.append(f"You: {user_input}")
            model_history.append(inference_result)
            gradio_history.append((user_input, inference_result_for_gradio))

            return None, model_history, gradio_history, gradio_history

        def _regenerate(
            model_history,
            gradio_history,
            generation_settings,
            *char_setting_states,
        ):
            '''Regenerates the last response.'''
            # The model history stores two strings per exchange ("You: ..." plus
            # the bot's reply), while the Gradio history stores a single
            # (user, bot) tuple, hence the different slice lengths below.
            return _run_inference(
                model_history[:-2],
                gradio_history[:-1],
                model_history[-2].replace("You: ", ""),
                generation_settings,
                *char_setting_states,
            )

        def _undo_last_exchange(model_history, gradio_history):
            '''Undoes the last exchange (message pair).'''
            return model_history[:-2], gradio_history[:-1], gradio_history[:-1]

        def _save_chat_history(model_history, *char_setting_states):
            '''Saves the current chat history to a .json file.'''
            char_name = char_setting_states[0]
            with open(f"{char_name}_conversation.json", "w") as f:
                f.write(json.dumps({"chat": model_history}))
            return f"{char_name}_conversation.json"
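
        # Shape of the saved file (values illustrative; the model history holds
        # two plain strings per exchange):
        #   {"chat": ["You: Hello!", "Eve: Hi there.", "You: How are you?", ...]}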

        def _load_chat_history(file_obj, *char_setting_states):
            '''Loads up a chat history from a .json file.'''
            # #############################################################################################
            # TODO(TG): Automatically detect and convert any CAI dump files loaded in to Pygmalion format #
            # #############################################################################################

            # https://stackoverflow.com/questions/5389507/iterating-over-every-two-elements-in-a-list
            def pairwise(iterable):
                # "s -> (s0, s1), (s2, s3), (s4, s5), ..."
                a = iter(iterable)
                return zip(a, a)
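
            # Illustrative: pairwise(["You: hi", "Eve: hello!", "You: bye", "Eve: bye!"])
            # yields ("You: hi", "Eve: hello!") then ("You: bye", "Eve: bye!"),
            # matching the (human_turn, bot_turn) unpacking in the loop below.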

            char_name = char_setting_states[0]
            user_name = char_setting_states[1]

            file_data = json.loads(file_obj.decode('utf-8'))
            model_history = file_data["chat"]
            # Construct a new gradio history
            new_gradio_history = []
            for human_turn, bot_turn in pairwise(model_history):
                # Handle the situation where convo history may be loaded before character defs
                if char_name == "":
                    # Grab char name from the model history
                    char_name = bot_turn.split(":")[0]
                # Format the user and bot utterances
                user_turn = human_turn.replace("You: ", "")
                bot_turn = bot_turn.replace(f"{char_name}:", f"**{char_name}:**")

                # Somebody released a script on /g/ which tries to convert CAI dump logs
                # to Pygmalion character settings and chats. The anonymization of the
                # dumps, however, leaves [NAME_IN_MESSAGE_REDACTED] in the conversational
                # history, which we obviously don't want. Substituting it here
                # accommodates users of that script, so the placeholder doesn't have to
                # be manually edited out of the conversation JSON. The model itself
                # should never generate [NAME_IN_MESSAGE_REDACTED].
                user_turn = user_turn.replace("[NAME_IN_MESSAGE_REDACTED]", user_name)
                bot_turn = bot_turn.replace("[NAME_IN_MESSAGE_REDACTED]", user_name)

                new_gradio_history.append((user_turn, bot_turn))

            return model_history, new_gradio_history, new_gradio_history

        with gr.Tab("Character Settings"):
            charfile, char_setting_states = _build_character_settings_ui()

        with gr.Tab("Chat Window"):
            chatbot = gr.Chatbot(
                label="Your conversation will show up here").style(
                    color_map=("#326efd", "#212528"))

            char_name, _user_name, char_persona, char_greeting, world_scenario, example_dialogue = char_setting_states
            charfile.upload(
                fn=_char_file_upload,
                inputs=[charfile, history_for_model, history_for_gradio],
                outputs=[history_for_model, history_for_gradio, chatbot, char_name, char_persona, char_greeting, world_scenario, example_dialogue]
            )

            message = gr.Textbox(
                label="Your message (hit Enter to send)",
                placeholder="Write a message...",
            )
            message.submit(
                fn=_run_inference,
                inputs=[
                    history_for_model, history_for_gradio, message,
                    generation_settings, *char_setting_states
                ],
                outputs=[
                    message, history_for_model, history_for_gradio, chatbot
                ],
            )

            with gr.Row():
                send_btn = gr.Button("Send", variant="primary")
                send_btn.click(
                    fn=_run_inference,
                    inputs=[
                        history_for_model, history_for_gradio, message,
                        generation_settings, *char_setting_states
                    ],
                    outputs=[
                        message, history_for_model, history_for_gradio, chatbot
                    ],
                )

                regenerate_btn = gr.Button("Regenerate")
                regenerate_btn.click(
                    fn=_regenerate,
                    inputs=[
                        history_for_model, history_for_gradio,
                        generation_settings, *char_setting_states
                    ],
                    outputs=[
                        message, history_for_model, history_for_gradio, chatbot
                    ],
                )

                undo_btn = gr.Button("Undo last exchange")
                undo_btn.click(
                    fn=_undo_last_exchange,
                    inputs=[history_for_model, history_for_gradio],
                    outputs=[history_for_model, history_for_gradio, chatbot],
                )

            with gr.Row():
                with gr.Column():
                    chatfile = gr.File(type="binary", file_types=[".json"], interactive=True)
                    chatfile.upload(
                        fn=_load_chat_history,
                        inputs=[chatfile, *char_setting_states],
                        outputs=[history_for_model, history_for_gradio, chatbot]
                    )

                    save_chat_btn = gr.Button(value="Save Conversation History")
                    save_chat_btn.click(
                        fn=_save_chat_history,
                        inputs=[history_for_model, *char_setting_states],
                        outputs=[chatfile],
                    )
                with gr.Column():
                    gr.Markdown("""
                        ### To save a chat
                        Click "Save Conversation History". The file will appear above the button and you can click to download it.

                        ### To load a chat
                        Drag a valid .json file onto the upload box, or click the box to browse.

                        **Remember to fill out/load up your character definitions before resuming a chat!**
                    """)

        with gr.Tab("Generation Settings"):
            _build_generation_settings_ui(
                state=generation_settings,
                fn=_update_generation_settings,
                for_kobold=for_kobold,
            )

    return interface


def _char_file_upload(file_obj, history_model, history_gradio):
    '''Loads character settings (and possibly a greeting) from a .json file.'''
    file_data = json.loads(file_obj.decode('utf-8'))
    char_name = file_data["char_name"]
    greeting = file_data["char_greeting"]
    empty_history = not history_model or (len(history_model) <= 2 and history_model[0] == '')
    if empty_history and char_name and greeting:
        # If the chat history is empty so far and the character has a greeting,
        # open the chat with that greeting.
        greeting_for_model = f'{char_name}: {greeting}'
        greeting_for_gradio = f'**{char_name}:** {greeting}'
        history_model = ['', greeting_for_model]
        history_gradio = [('', greeting_for_gradio)]
    return (history_model, history_gradio, history_gradio, char_name,
            file_data["char_persona"], greeting, file_data["world_scenario"],
            file_data["example_dialogue"])
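
# Character files are flat JSON objects; the keys read in _char_file_upload
# mirror what char_file_create() writes below (values illustrative):
#   {"char_name": "Eve", "char_persona": "A curious android.",
#    "char_greeting": "Hello!", "world_scenario": "", "example_dialogue": ""}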

def _build_character_settings_ui():

    def char_file_create(char_name, char_persona, char_greeting, world_scenario, example_dialogue):
        '''Saves the character settings to a .json file.'''
        with open(char_name + ".json", "w") as f:
            f.write(json.dumps({
                "char_name": char_name,
                "char_persona": char_persona,
                "char_greeting": char_greeting,
                "world_scenario": world_scenario,
                "example_dialogue": example_dialogue,
            }))
        return char_name + ".json"

    with gr.Column():
        with gr.Row():
            char_name = gr.Textbox(
                label="Character Name",
                placeholder="The character's name",
            )
            user_name = gr.Textbox(
                label="Your Name",
                placeholder="What the character should call you",
            )

        char_persona = gr.Textbox(
            label="Character Persona",
            placeholder=
            "Describe the character's persona here. Think of this as CharacterAI's description + definitions in one box.",
            lines=4,
        )
        char_greeting = gr.Textbox(
            label="Character Greeting",
            placeholder=
            "Write the character's greeting here. They will say this verbatim as their first response.",
            lines=3,
        )

        world_scenario = gr.Textbox(
            label="Scenario",
            placeholder=
            "Optionally, describe the starting scenario in a few short sentences.",
        )
        example_dialogue = gr.Textbox(
            label="Example Chat",
            placeholder=
            "Optionally, write an example chat here. This is useful for showing how the character should speak and behave.",
            lines=4,
        )

        with gr.Row():
            with gr.Column():
                charfile = gr.File(type="binary", file_types=[".json"])

                save_char_btn = gr.Button(value="Generate Character File")
                save_char_btn.click(
                    fn=char_file_create,
                    inputs=[
                        char_name, char_persona, char_greeting, world_scenario,
                        example_dialogue
                    ],
                    outputs=[charfile],
                )
            with gr.Column():
                gr.Markdown("""
                    ### To save a character
                    Click "Generate Character File". The file will appear above the button and you can click to download it.

                    ### To upload a character
                    Drag a valid .json file onto the upload box, or click the box to browse.
                """)

    return charfile, (char_name, user_name, char_persona, char_greeting, world_scenario, example_dialogue)


def _build_generation_settings_ui(state, fn, for_kobold):
    generation_defaults = get_generation_defaults(for_kobold=for_kobold)

    with gr.Row():
        with gr.Column():
            max_new_tokens = gr.Slider(
                16,
                512,
                value=generation_defaults["max_new_tokens"],
                step=4,
                label="max_new_tokens",
            )
            max_new_tokens.change(
                lambda state, value: fn(state, "max_new_tokens", value),
                inputs=[state, max_new_tokens],
                outputs=state,
            )

            temperature = gr.Slider(
                0.1,
                2,
                value=generation_defaults["temperature"],
                step=0.01,
                label="temperature",
            )
            temperature.change(
                lambda state, value: fn(state, "temperature", value),
                inputs=[state, temperature],
                outputs=state,
            )

            top_p = gr.Slider(
                0.0,
                1.0,
                value=generation_defaults["top_p"],
                step=0.01,
                label="top_p",
            )
            top_p.change(
                lambda state, value: fn(state, "top_p", value),
                inputs=[state, top_p],
                outputs=state,
            )

        with gr.Column():
            typical_p = gr.Slider(
                0.0,
                1.0,
                value=generation_defaults["typical_p"],
                step=0.01,
                label="typical_p",
            )
            typical_p.change(
                lambda state, value: fn(state, "typical_p", value),
                inputs=[state, typical_p],
                outputs=state,
            )

            repetition_penalty = gr.Slider(
                1.0,
                3.0,
                value=generation_defaults["repetition_penalty"],
                step=0.01,
                label="repetition_penalty",
            )
            repetition_penalty.change(
                lambda state, value: fn(state, "repetition_penalty", value),
                inputs=[state, repetition_penalty],
                outputs=state,
            )

            top_k = gr.Slider(
                0,
                100,
                value=generation_defaults["top_k"],
                step=1,
                label="top_k",
            )
            top_k.change(
                lambda state, value: fn(state, "top_k", value),
                inputs=[state, top_k],
                outputs=state,
            )

            if not for_kobold:
                penalty_alpha = gr.Slider(
                    0,
                    1,
                    value=generation_defaults["penalty_alpha"],
                    step=0.05,
                    label="penalty_alpha",
                )
                penalty_alpha.change(
                    lambda state, value: fn(state, "penalty_alpha", value),
                    inputs=[state, penalty_alpha],
                    outputs=state,
                )

    #
    # Some of these explanations are taken from Kobold:
    # https://github.com/KoboldAI/KoboldAI-Client/blob/main/gensettings.py
    #
    # They're passed directly into the `generate` call, so they should exist here:
    # https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig
    #
    with gr.Accordion(label="Helpful information", open=False):
        gr.Markdown("""
        Here's a basic rundown of each setting:

        - `max_new_tokens`: Number of tokens the AI should generate. Higher numbers will take longer to generate.
        - `temperature`: Randomness of sampling. Higher values can increase creativity, but may make the text less sensible. Lower values make the text more predictable, but it can become repetitious.
        - `top_p`: Used to discard unlikely text in the sampling process. Lower values make the text more predictable, but it can become repetitious. (Set this to 1 to disable its effect.)
        - `top_k`: Alternative sampling method which can be combined with `top_p`. The number of highest-probability vocabulary tokens to keep for top-k filtering. (Set this to 0 to disable its effect.)
        - `typical_p`: Alternative sampling method described in the paper "Typical Decoding for Natural Language Generation" (10.48550/ARXIV.2202.00666). The paper suggests 0.2 as a good value for this setting. (Set this to 1 to disable its effect.)
        - `repetition_penalty`: Used to penalize words that were already generated or belong to the context. (Going over 1.2 breaks 6B models; set this to 1.0 to disable its effect.)
        - `penalty_alpha`: The alpha coefficient when using contrastive search.

        Some settings might not show up depending on which inference backend is being used.
        """)