import json import logging import gradio as gr def get_generation_defaults(for_kobold): defaults = { "do_sample": True, "max_new_tokens": 196, "temperature": 0.5, "top_p": 0.9, "top_k": 0, "typical_p": 1.0, "repetition_penalty": 1.05, } if for_kobold: defaults.update({"max_context_length": 768}) else: defaults.update({"penalty_alpha": 0.6}) return defaults logger = logging.getLogger(__name__) def build_gradio_ui_for(inference_fn, for_kobold): ''' Builds a Gradio UI to interact with the model. Big thanks to TearGosling for the initial version that inspired this. ''' with gr.Blocks(title="Pygmalion", analytics_enabled=False) as interface: history_for_gradio = gr.State([]) history_for_model = gr.State([]) generation_settings = gr.State( get_generation_defaults(for_kobold=for_kobold)) def _update_generation_settings( original_settings, param_name, new_value, ): ''' Merges `{param_name: new_value}` into `original_settings` and returns a new dictionary. ''' updated_settings = {**original_settings, param_name: new_value} logging.debug("Generation settings updated to: `%s`", updated_settings) return updated_settings def _run_inference( model_history, gradio_history, user_input, generation_settings, *char_setting_states, ): ''' Runs inference on the model, and formats the returned response for the Gradio state and chatbot component. ''' char_name = char_setting_states[0] user_name = char_setting_states[1] # If user input is blank, format it as if user was silent if user_input is None or user_input.strip() == "": user_input = "..." inference_result = inference_fn(model_history, user_input, generation_settings, *char_setting_states) inference_result_for_gradio = inference_result \ .replace(f"{char_name}:", f"**{char_name}:**") \ .replace("", user_name) \ .replace("\n", "
") # Gradio chatbot component can display br tag as linebreak model_history.append(f"You: {user_input}") model_history.append(inference_result) gradio_history.append((user_input, inference_result_for_gradio)) return None, model_history, gradio_history, gradio_history def _regenerate( model_history, gradio_history, generation_settings, *char_setting_states, ): '''Regenerates the last response.''' return _run_inference( model_history[:-2], gradio_history[:-1], model_history[-2].replace("You: ", ""), generation_settings, *char_setting_states, ) def _undo_last_exchange(model_history, gradio_history): '''Undoes the last exchange (message pair).''' return model_history[:-2], gradio_history[:-1], gradio_history[:-1] def _save_chat_history(model_history, *char_setting_states): '''Saves the current chat history to a .json file.''' char_name = char_setting_states[0] with open(f"{char_name}_conversation.json", "w") as f: f.write(json.dumps({"chat": model_history})) return f"{char_name}_conversation.json" def _load_chat_history(file_obj, *char_setting_states): '''Loads up a chat history from a .json file.''' # ############################################################################################# # TODO(TG): Automatically detect and convert any CAI dump files loaded in to Pygmalion format # # ############################################################################################# # https://stackoverflow.com/questions/5389507/iterating-over-every-two-elements-in-a-list def pairwise(iterable): # "s -> (s0, s1), (s2, s3), (s4, s5), ..." a = iter(iterable) return zip(a, a) char_name = char_setting_states[0] user_name = char_setting_states[1] file_data = json.loads(file_obj.decode('utf-8')) model_history = file_data["chat"] # Construct a new gradio history new_gradio_history = [] for human_turn, bot_turn in pairwise(model_history): # Handle the situation where convo history may be loaded before character defs if char_name == "": # Grab char name from the model history char_name = bot_turn.split(":")[0] # Format the user and bot utterances user_turn = human_turn.replace("You: ", "") bot_turn = bot_turn.replace(f"{char_name}:", f"**{char_name}:**") # Somebody released a script on /g/ which tries to convert CAI dump logs # to Pygmalion character settings and chats. The anonymization of the dumps, however, means that # [NAME_IN_MESSAGE_REDACTED] is left in the conversational history. We obviously wouldn't want this # This therefore accomodates users of that script, so that [NAME_IN_MESSAGE_REDACTED] doesn't have # to be manually edited in the conversation JSON. # The model shouldn't generate [NAME_IN_MESSAGE_REDACTED] by itself. user_turn = user_turn.replace("[NAME_IN_MESSAGE_REDACTED]", user_name) bot_turn = bot_turn.replace("[NAME_IN_MESSAGE_REDACTED]", user_name) new_gradio_history.append((user_turn, bot_turn)) return model_history, new_gradio_history, new_gradio_history with gr.Tab("Character Settings") as settings_tab: charfile, char_setting_states = _build_character_settings_ui() with gr.Tab("Chat Window"): chatbot = gr.Chatbot( label="Your conversation will show up here").style( color_map=("#326efd", "#212528")) char_name, _user_name, char_persona, char_greeting, world_scenario, example_dialogue = char_setting_states charfile.upload( fn=_char_file_upload, inputs=[charfile, history_for_model, history_for_gradio], outputs=[history_for_model, history_for_gradio, chatbot, char_name, char_persona, char_greeting, world_scenario, example_dialogue] ) message = gr.Textbox( label="Your message (hit Enter to send)", placeholder="Write a message...", ) message.submit( fn=_run_inference, inputs=[ history_for_model, history_for_gradio, message, generation_settings, *char_setting_states ], outputs=[ message, history_for_model, history_for_gradio, chatbot ], ) with gr.Row(): send_btn = gr.Button("Send", variant="primary") send_btn.click( fn=_run_inference, inputs=[ history_for_model, history_for_gradio, message, generation_settings, *char_setting_states ], outputs=[ message, history_for_model, history_for_gradio, chatbot ], ) regenerate_btn = gr.Button("Regenerate") regenerate_btn.click( fn=_regenerate, inputs=[ history_for_model, history_for_gradio, generation_settings, *char_setting_states ], outputs=[ message, history_for_model, history_for_gradio, chatbot ], ) undo_btn = gr.Button("Undo last exchange") undo_btn.click( fn=_undo_last_exchange, inputs=[history_for_model, history_for_gradio], outputs=[history_for_model, history_for_gradio, chatbot], ) with gr.Row(): with gr.Column(): chatfile = gr.File(type="binary", file_types=[".json"], interactive=True) chatfile.upload( fn=_load_chat_history, inputs=[chatfile, *char_setting_states], outputs=[history_for_model, history_for_gradio, chatbot] ) save_char_btn = gr.Button(value="Save Conversation History") save_char_btn.click(_save_chat_history, inputs=[history_for_model, *char_setting_states], outputs=[chatfile]) with gr.Column(): gr.Markdown(""" ### To save a chat Click "Save Conversation History". The file will appear above the button and you can click to download it. ### To load a chat Drag a valid .json file onto the upload box, or click the box to browse. **Remember to fill out/load up your character definitions before resuming a chat!** """) with gr.Tab("Generation Settings"): _build_generation_settings_ui( state=generation_settings, fn=_update_generation_settings, for_kobold=for_kobold, ) return interface def _char_file_upload(file_obj, history_model, history_gradio): file_data = json.loads(file_obj.decode('utf-8')) char_name = file_data["char_name"] greeting = file_data["char_greeting"] empty_history = not history_model or (len(history_model) <= 2 and history_model[0] == '') if empty_history and char_name and greeting: # if chat history is empty so far, and there is a character greeting, add character greeting to the chat s = f'{char_name}: {greeting}' t = f'**{char_name}**: {greeting}' history_model = ['', s] history_gradio = [('', t)] return history_model, history_gradio, history_gradio, char_name, file_data["char_persona"], greeting, file_data["world_scenario"], file_data["example_dialogue"] def _build_character_settings_ui(): def char_file_create(char_name, char_persona, char_greeting, world_scenario, example_dialogue): with open(char_name + ".json", "w") as f: f.write(json.dumps({"char_name": char_name, "char_persona": char_persona, "char_greeting": char_greeting, "world_scenario": world_scenario, "example_dialogue": example_dialogue})) return char_name + ".json" with gr.Column(): with gr.Row(): char_name = gr.Textbox( label="Character Name", placeholder="The character's name", ) user_name = gr.Textbox( label="Your Name", placeholder="How the character should call you", ) char_persona = gr.Textbox( label="Character Persona", placeholder= "Describe the character's persona here. Think of this as CharacterAI's description + definitions in one box.", lines=4, ) char_greeting = gr.Textbox( label="Character Greeting", placeholder= "Write the character's greeting here. They will say this verbatim as their first response.", lines=3, ) world_scenario = gr.Textbox( label="Scenario", placeholder= "Optionally, describe the starting scenario in a few short sentences.", ) example_dialogue = gr.Textbox( label="Example Chat", placeholder= "Optionally, write in an example chat here. This is useful for showing how the character should behave, for example.", lines=4, ) with gr.Row(): with gr.Column(): charfile = gr.File(type="binary", file_types=[".json"]) save_char_btn = gr.Button(value="Generate Character File") save_char_btn.click(char_file_create, inputs=[char_name, char_persona, char_greeting, world_scenario, example_dialogue], outputs=[charfile]) with gr.Column(): gr.Markdown(""" ### To save a character Click "Generate Character File". The file will appear above the button and you can click to download it. ### To upload a character Drag a valid .json file onto the upload box, or click the box to browse. """) return charfile, (char_name, user_name, char_persona, char_greeting, world_scenario, example_dialogue) def _build_generation_settings_ui(state, fn, for_kobold): generation_defaults = get_generation_defaults(for_kobold=for_kobold) with gr.Row(): with gr.Column(): max_new_tokens = gr.Slider( 16, 512, value=generation_defaults["max_new_tokens"], step=4, label="max_new_tokens", ) max_new_tokens.change( lambda state, value: fn(state, "max_new_tokens", value), inputs=[state, max_new_tokens], outputs=state, ) temperature = gr.Slider( 0.1, 2, value=generation_defaults["temperature"], step=0.01, label="temperature", ) temperature.change( lambda state, value: fn(state, "temperature", value), inputs=[state, temperature], outputs=state, ) top_p = gr.Slider( 0.0, 1.0, value=generation_defaults["top_p"], step=0.01, label="top_p", ) top_p.change( lambda state, value: fn(state, "top_p", value), inputs=[state, top_p], outputs=state, ) with gr.Column(): typical_p = gr.Slider( 0.0, 1.0, value=generation_defaults["typical_p"], step=0.01, label="typical_p", ) typical_p.change( lambda state, value: fn(state, "typical_p", value), inputs=[state, typical_p], outputs=state, ) repetition_penalty = gr.Slider( 1.0, 3.0, value=generation_defaults["repetition_penalty"], step=0.01, label="repetition_penalty", ) repetition_penalty.change( lambda state, value: fn(state, "repetition_penalty", value), inputs=[state, repetition_penalty], outputs=state, ) top_k = gr.Slider( 0, 100, value=generation_defaults["top_k"], step=1, label="top_k", ) top_k.change( lambda state, value: fn(state, "top_k", value), inputs=[state, top_k], outputs=state, ) if not for_kobold: penalty_alpha = gr.Slider( 0, 1, value=generation_defaults["penalty_alpha"], step=0.05, label="penalty_alpha", ) penalty_alpha.change( lambda state, value: fn(state, "penalty_alpha", value), inputs=[state, penalty_alpha], outputs=state, ) # # Some of these explanations are taken from Kobold: # https://github.com/KoboldAI/KoboldAI-Client/blob/main/gensettings.py # # They're passed directly into the `generate` call, so they should exist here: # https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig # with gr.Accordion(label="Helpful information", open=False): gr.Markdown(""" Here's a basic rundown of each setting: - `max_new_tokens`: Number of tokens the AI should generate. Higher numbers will take longer to generate. - `temperature`: Randomness of sampling. High values can increase creativity but may make text less sensible. Lower values will make text more predictable but can become repetitious. - `top_p`: Used to discard unlikely text in the sampling process. Lower values will make text more predictable but can become repetitious. (Put this value on 1 to disable its effect) - `top_k`: Alternative sampling method, can be combined with top_p. The number of highest probability vocabulary tokens to keep for top-k-filtering. (Put this value on 0 to disable its effect) - `typical_p`: Alternative sampling method described in the paper "Typical_p Decoding for Natural Language Generation" (10.48550/ARXIV.2202.00666). The paper suggests 0.2 as a good value for this setting. Set this setting to 1 to disable its effect. - `repetition_penalty`: Used to penalize words that were already generated or belong to the context (Going over 1.2 breaks 6B models. Set to 1.0 to disable). - `penalty_alpha`: The alpha coefficient when using contrastive search. Some settings might not show up depending on which inference backend is being used. """)