import json
import logging
import gradio as gr
def get_generation_defaults(for_kobold):
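    '''
    Returns the default generation settings, adjusted for the backend in use.

    Example:
    >>> get_generation_defaults(for_kobold=True)["max_context_length"]
    768
    >>> get_generation_defaults(for_kobold=False)["penalty_alpha"]
    0.6
    '''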
defaults = {
"do_sample": True,
"max_new_tokens": 196,
"temperature": 0.5,
"top_p": 0.9,
"top_k": 0,
"typical_p": 1.0,
"repetition_penalty": 1.05,
}
if for_kobold:
defaults.update({"max_context_length": 768})
else:
defaults.update({"penalty_alpha": 0.6})
return defaults
logger = logging.getLogger(__name__)
def build_gradio_ui_for(inference_fn, for_kobold):
'''
Builds a Gradio UI to interact with the model. Big thanks to TearGosling for
the initial version that inspired this.
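
    Usage sketch (assuming an `inference_fn(model_history, user_input,
    generation_settings, *char_settings) -> str` implemented elsewhere;
    `my_inference_fn` below is hypothetical):

        interface = build_gradio_ui_for(my_inference_fn, for_kobold=False)
        interface.launch()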
'''
with gr.Blocks(title="Pygmalion", analytics_enabled=False) as interface:
history_for_gradio = gr.State([])
history_for_model = gr.State([])
generation_settings = gr.State(
get_generation_defaults(for_kobold=for_kobold))
def _update_generation_settings(
original_settings,
param_name,
new_value,
):
'''
Merges `{param_name: new_value}` into `original_settings` and
returns a new dictionary.
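
        Example:
        >>> _update_generation_settings({"top_p": 0.9}, "top_p", 0.5)
        {'top_p': 0.5}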
'''
updated_settings = {**original_settings, param_name: new_value}
            logger.debug("Generation settings updated to: `%s`",
                         updated_settings)
return updated_settings
def _run_inference(
model_history,
gradio_history,
user_input,
generation_settings,
*char_setting_states,
):
'''
Runs inference on the model, and formats the returned response for
the Gradio state and chatbot component.
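
        For example, "Alice: Hi!" in the raw reply is shown as "**Alice:** Hi!",
        <USER> is replaced with the user's name, and newlines become <br> tags.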
'''
char_name = char_setting_states[0]
user_name = char_setting_states[1]
            # If the user input is blank, treat it as if the user stayed silent
if user_input is None or user_input.strip() == "":
user_input = "..."
inference_result = inference_fn(model_history, user_input,
generation_settings,
*char_setting_states)
inference_result_for_gradio = inference_result \
.replace(f"{char_name}:", f"**{char_name}:**") \
.replace("<USER>", user_name) \
                .replace("\n", "<br>")  # the Gradio chatbot renders <br> tags as line breaks
model_history.append(f"You: {user_input}")
model_history.append(inference_result)
gradio_history.append((user_input, inference_result_for_gradio))
return None, model_history, gradio_history, gradio_history
def _regenerate(
model_history,
gradio_history,
generation_settings,
*char_setting_states,
):
'''Regenerates the last response.'''
return _run_inference(
model_history[:-2],
gradio_history[:-1],
model_history[-2].replace("You: ", ""),
generation_settings,
*char_setting_states,
)
def _undo_last_exchange(model_history, gradio_history):
'''Undoes the last exchange (message pair).'''
return model_history[:-2], gradio_history[:-1], gradio_history[:-1]
def _save_chat_history(model_history, *char_setting_states):
'''Saves the current chat history to a .json file.'''
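            # Resulting file layout, mirroring the in-memory model history:
            #   {"chat": ["You: <message>", "<char_name>: <reply>", ...]}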
char_name = char_setting_states[0]
with open(f"{char_name}_conversation.json", "w") as f:
f.write(json.dumps({"chat": model_history}))
return f"{char_name}_conversation.json"
def _load_chat_history(file_obj, *char_setting_states):
'''Loads up a chat history from a .json file.'''
# #############################################################################################
# TODO(TG): Automatically detect and convert any CAI dump files loaded in to Pygmalion format #
# #############################################################################################
# https://stackoverflow.com/questions/5389507/iterating-over-every-two-elements-in-a-list
def pairwise(iterable):
# "s -> (s0, s1), (s2, s3), (s4, s5), ..."
a = iter(iterable)
return zip(a, a)
char_name = char_setting_states[0]
user_name = char_setting_states[1]
file_data = json.loads(file_obj.decode('utf-8'))
model_history = file_data["chat"]
# Construct a new gradio history
new_gradio_history = []
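            # model_history alternates human and bot turns, so iterating over it
            # pairwise reconstructs each (user, bot) exchange for the chatbot.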
for human_turn, bot_turn in pairwise(model_history):
# Handle the situation where convo history may be loaded before character defs
if char_name == "":
# Grab char name from the model history
char_name = bot_turn.split(":")[0]
# Format the user and bot utterances
user_turn = human_turn.replace("You: ", "")
bot_turn = bot_turn.replace(f"{char_name}:", f"**{char_name}:**")
                # Somebody released a script on /g/ which tries to convert CAI dump
                # logs to Pygmalion character settings and chats. Because the dumps
                # are anonymized, however, [NAME_IN_MESSAGE_REDACTED] is left in the
                # conversation history, which we obviously don't want to display.
                # Substituting the user's name here accommodates users of that
                # script, so [NAME_IN_MESSAGE_REDACTED] doesn't have to be edited
                # out of the conversation JSON by hand. The model shouldn't generate
                # [NAME_IN_MESSAGE_REDACTED] by itself.
user_turn = user_turn.replace("[NAME_IN_MESSAGE_REDACTED]", user_name)
bot_turn = bot_turn.replace("[NAME_IN_MESSAGE_REDACTED]", user_name)
new_gradio_history.append((user_turn, bot_turn))
return model_history, new_gradio_history, new_gradio_history
with gr.Tab("Character Settings") as settings_tab:
charfile, char_setting_states = _build_character_settings_ui()
with gr.Tab("Chat Window"):
chatbot = gr.Chatbot(
label="Your conversation will show up here").style(
color_map=("#326efd", "#212528"))
char_name, _user_name, char_persona, char_greeting, world_scenario, example_dialogue = char_setting_states
charfile.upload(
fn=_char_file_upload,
inputs=[charfile, history_for_model, history_for_gradio],
outputs=[history_for_model, history_for_gradio, chatbot, char_name, char_persona, char_greeting, world_scenario, example_dialogue]
)
message = gr.Textbox(
label="Your message (hit Enter to send)",
placeholder="Write a message...",
)
message.submit(
fn=_run_inference,
inputs=[
history_for_model, history_for_gradio, message,
generation_settings, *char_setting_states
],
outputs=[
message, history_for_model, history_for_gradio, chatbot
],
)
with gr.Row():
send_btn = gr.Button("Send", variant="primary")
send_btn.click(
fn=_run_inference,
inputs=[
history_for_model, history_for_gradio, message,
generation_settings, *char_setting_states
],
outputs=[
message, history_for_model, history_for_gradio, chatbot
],
)
regenerate_btn = gr.Button("Regenerate")
regenerate_btn.click(
fn=_regenerate,
inputs=[
history_for_model, history_for_gradio,
generation_settings, *char_setting_states
],
outputs=[
message, history_for_model, history_for_gradio, chatbot
],
)
undo_btn = gr.Button("Undo last exchange")
undo_btn.click(
fn=_undo_last_exchange,
inputs=[history_for_model, history_for_gradio],
outputs=[history_for_model, history_for_gradio, chatbot],
)
with gr.Row():
with gr.Column():
chatfile = gr.File(type="binary", file_types=[".json"], interactive=True)
chatfile.upload(
fn=_load_chat_history,
inputs=[chatfile, *char_setting_states],
outputs=[history_for_model, history_for_gradio, chatbot]
)
save_char_btn = gr.Button(value="Save Conversation History")
save_char_btn.click(_save_chat_history, inputs=[history_for_model, *char_setting_states], outputs=[chatfile])
with gr.Column():
gr.Markdown("""
### To save a chat
Click "Save Conversation History". The file will appear above the button and you can click to download it.
### To load a chat
Drag a valid .json file onto the upload box, or click the box to browse.
**Remember to fill out or load your character definitions before resuming a chat!**
""")
with gr.Tab("Generation Settings"):
_build_generation_settings_ui(
state=generation_settings,
fn=_update_generation_settings,
for_kobold=for_kobold,
)
return interface
def _char_file_upload(file_obj, history_model, history_gradio):
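    '''
    Reads character definitions out of an uploaded .json character file and,
    if the chat is still empty, seeds it with the character's greeting.

    Expected file layout (the same one char_file_create writes):
        {"char_name": ..., "char_persona": ..., "char_greeting": ...,
         "world_scenario": ..., "example_dialogue": ...}
    '''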
file_data = json.loads(file_obj.decode('utf-8'))
char_name = file_data["char_name"]
greeting = file_data["char_greeting"]
empty_history = not history_model or (len(history_model) <= 2 and history_model[0] == '')
if empty_history and char_name and greeting:
        # If the chat history is empty so far and the character has a greeting,
        # seed the chat with the greeting as the first bot message
        s = f'{char_name}: {greeting}'
        t = f'**{char_name}:** {greeting}'
history_model = ['', s]
history_gradio = [('', t)]
return history_model, history_gradio, history_gradio, char_name, file_data["char_persona"], greeting, file_data["world_scenario"], file_data["example_dialogue"]
def _build_character_settings_ui():
def char_file_create(char_name, char_persona, char_greeting, world_scenario, example_dialogue):
        with open(char_name + ".json", "w") as f:
            json.dump(
                {
                    "char_name": char_name,
                    "char_persona": char_persona,
                    "char_greeting": char_greeting,
                    "world_scenario": world_scenario,
                    "example_dialogue": example_dialogue,
                },
                f,
            )
return char_name + ".json"
with gr.Column():
with gr.Row():
char_name = gr.Textbox(
label="Character Name",
placeholder="The character's name",
)
user_name = gr.Textbox(
label="Your Name",
placeholder="How the character should call you",
)
char_persona = gr.Textbox(
label="Character Persona",
placeholder=
"Describe the character's persona here. Think of this as CharacterAI's description + definitions in one box.",
lines=4,
)
char_greeting = gr.Textbox(
label="Character Greeting",
placeholder=
"Write the character's greeting here. They will say this verbatim as their first response.",
lines=3,
)
world_scenario = gr.Textbox(
label="Scenario",
placeholder=
"Optionally, describe the starting scenario in a few short sentences.",
)
example_dialogue = gr.Textbox(
label="Example Chat",
placeholder=
"Optionally, write in an example chat here. This is useful for showing how the character should behave, for example.",
lines=4,
)
with gr.Row():
with gr.Column():
charfile = gr.File(type="binary", file_types=[".json"])
save_char_btn = gr.Button(value="Generate Character File")
save_char_btn.click(char_file_create, inputs=[char_name, char_persona, char_greeting, world_scenario, example_dialogue], outputs=[charfile])
with gr.Column():
gr.Markdown("""
### To save a character
Click "Generate Character File". The file will appear above the button and you can click to download it.
### To upload a character
Drag a valid .json file onto the upload box, or click the box to browse.
""")
return charfile, (char_name, user_name, char_persona, char_greeting, world_scenario, example_dialogue)
def _build_generation_settings_ui(state, fn, for_kobold):
generation_defaults = get_generation_defaults(for_kobold=for_kobold)
with gr.Row():
with gr.Column():
max_new_tokens = gr.Slider(
16,
512,
value=generation_defaults["max_new_tokens"],
step=4,
label="max_new_tokens",
)
max_new_tokens.change(
lambda state, value: fn(state, "max_new_tokens", value),
inputs=[state, max_new_tokens],
outputs=state,
)
temperature = gr.Slider(
0.1,
2,
value=generation_defaults["temperature"],
step=0.01,
label="temperature",
)
temperature.change(
lambda state, value: fn(state, "temperature", value),
inputs=[state, temperature],
outputs=state,
)
top_p = gr.Slider(
0.0,
1.0,
value=generation_defaults["top_p"],
step=0.01,
label="top_p",
)
top_p.change(
lambda state, value: fn(state, "top_p", value),
inputs=[state, top_p],
outputs=state,
)
with gr.Column():
typical_p = gr.Slider(
0.0,
1.0,
value=generation_defaults["typical_p"],
step=0.01,
label="typical_p",
)
typical_p.change(
lambda state, value: fn(state, "typical_p", value),
inputs=[state, typical_p],
outputs=state,
)
repetition_penalty = gr.Slider(
1.0,
3.0,
value=generation_defaults["repetition_penalty"],
step=0.01,
label="repetition_penalty",
)
repetition_penalty.change(
lambda state, value: fn(state, "repetition_penalty", value),
inputs=[state, repetition_penalty],
outputs=state,
)
top_k = gr.Slider(
0,
100,
value=generation_defaults["top_k"],
step=1,
label="top_k",
)
top_k.change(
lambda state, value: fn(state, "top_k", value),
inputs=[state, top_k],
outputs=state,
)
if not for_kobold:
penalty_alpha = gr.Slider(
0,
1,
value=generation_defaults["penalty_alpha"],
step=0.05,
label="penalty_alpha",
)
penalty_alpha.change(
lambda state, value: fn(state, "penalty_alpha", value),
inputs=[state, penalty_alpha],
outputs=state,
)
#
# Some of these explanations are taken from Kobold:
# https://github.com/KoboldAI/KoboldAI-Client/blob/main/gensettings.py
#
# They're passed directly into the `generate` call, so they should exist here:
# https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig
#
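    # A commented sketch (hypothetical: the real backend lives behind
    # `inference_fn`, and the model name below is only an example) of how the
    # settings dict might reach a transformers backend:
    #
    #     from transformers import AutoModelForCausalLM, AutoTokenizer
    #     tokenizer = AutoTokenizer.from_pretrained("PygmalionAI/pygmalion-6b")
    #     model = AutoModelForCausalLM.from_pretrained("PygmalionAI/pygmalion-6b")
    #     inputs = tokenizer(prompt, return_tensors="pt")
    #     output_ids = model.generate(**inputs, **generation_settings)
    #     reply = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    #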
with gr.Accordion(label="Helpful information", open=False):
gr.Markdown("""
Here's a basic rundown of each setting:
- `max_new_tokens`: Number of tokens the AI should generate. Higher numbers will take longer to generate.
- `temperature`: Randomness of sampling. Higher values can increase creativity, but may make the text less sensible. Lower values make the text more predictable, but it can become repetitive.
- `top_p`: Used to discard unlikely text in the sampling process. Lower values make the text more predictable, but it can become repetitive. (Set this to 1.0 to disable its effect.)
- `top_k`: Alternative sampling method which can be combined with top_p. The number of highest-probability vocabulary tokens to keep for top-k filtering. (Set this to 0 to disable its effect.)
- `typical_p`: Alternative sampling method described in the paper "Typical Decoding for Natural Language Generation" (10.48550/ARXIV.2202.00666). The paper suggests 0.2 as a good value for this setting. (Set this to 1.0 to disable its effect.)
- `repetition_penalty`: Used to penalize words that were already generated or belong to the context. (Going over 1.2 breaks 6B models; set this to 1.0 to disable its effect.)
- `penalty_alpha`: The alpha coefficient when using contrastive search.
Some settings might not show up depending on which inference backend is being used.
""")