import json
import logging

import gradio as gr
def get_generation_defaults(for_kobold):
    defaults = {
        "do_sample": True,
        "max_new_tokens": 196,
        "temperature": 0.5,
        "top_p": 0.9,
        "top_k": 0,
        "typical_p": 1.0,
        "repetition_penalty": 1.05,
    }

    if for_kobold:
        defaults.update({"max_context_length": 768})
    else:
        defaults.update({"penalty_alpha": 0.6})

    return defaults

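# For reference, `get_generation_defaults(for_kobold=False)` returns:
#   {"do_sample": True, "max_new_tokens": 196, "temperature": 0.5,
#    "top_p": 0.9, "top_k": 0, "typical_p": 1.0,
#    "repetition_penalty": 1.05, "penalty_alpha": 0.6}
# With `for_kobold=True`, `penalty_alpha` is replaced by
# `{"max_context_length": 768}`.
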
logger = logging.getLogger(__name__)


def build_gradio_ui_for(inference_fn, for_kobold):
    '''
    Builds a Gradio UI to interact with the model. Big thanks to TearGosling
    for the initial version that inspired this.
    '''
    with gr.Blocks(title="Pygmalion", analytics_enabled=False) as interface:
        history_for_gradio = gr.State([])
        history_for_model = gr.State([])
        generation_settings = gr.State(
            get_generation_defaults(for_kobold=for_kobold))
        def _update_generation_settings(
            original_settings,
            param_name,
            new_value,
        ):
            '''
            Merges `{param_name: new_value}` into `original_settings` and
            returns a new dictionary.
            '''
            updated_settings = {**original_settings, param_name: new_value}
            logger.debug("Generation settings updated to: `%s`",
                         updated_settings)
            return updated_settings

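        # For example, a slider change event calling
        # `_update_generation_settings(settings, "top_p", 0.95)` leaves the
        # original dict untouched and returns a copy with `top_p` set to 0.95.
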
        def _run_inference(
            model_history,
            gradio_history,
            user_input,
            generation_settings,
            *char_setting_states,
        ):
            '''
            Runs inference on the model, and formats the returned response for
            the Gradio state and chatbot component.
            '''
            char_name = char_setting_states[0]
            user_name = char_setting_states[1]

            # If the user input is blank, format it as if the user stayed
            # silent.
            if user_input is None or user_input.strip() == "":
                user_input = "..."

            inference_result = inference_fn(model_history, user_input,
                                            generation_settings,
                                            *char_setting_states)
            # The Gradio chatbot component renders `<br>` as a line break.
            inference_result_for_gradio = inference_result \
                .replace(f"{char_name}:", f"**{char_name}:**") \
                .replace("<USER>", user_name) \
                .replace("\n", "<br>")

            model_history.append(f"You: {user_input}")
            model_history.append(inference_result)
            gradio_history.append((user_input, inference_result_for_gradio))

            return None, model_history, gradio_history, gradio_history

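        # Note on state layout: `model_history` is a flat list of strings,
        # alternating ["You: <message>", "<char_name>: <reply>", ...], while
        # `gradio_history` holds (user_turn, bot_turn) tuples for the chatbot
        # component. `_regenerate` and `_undo_last_exchange` below rely on
        # this 2:1 ratio when slicing.
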
        def _regenerate(
            model_history,
            gradio_history,
            generation_settings,
            *char_setting_states,
        ):
            '''Regenerates the last response.'''
            return _run_inference(
                model_history[:-2],
                gradio_history[:-1],
                model_history[-2].replace("You: ", ""),
                generation_settings,
                *char_setting_states,
            )

        def _undo_last_exchange(model_history, gradio_history):
            '''Undoes the last exchange (message pair).'''
            return model_history[:-2], gradio_history[:-1], gradio_history[:-1]

        def _save_chat_history(model_history, *char_setting_states):
            '''Saves the current chat history to a .json file.'''
            char_name = char_setting_states[0]
            with open(f"{char_name}_conversation.json", "w") as f:
                f.write(json.dumps({"chat": model_history}))
            return f"{char_name}_conversation.json"

        def _load_chat_history(file_obj, *char_setting_states):
            '''Loads up a chat history from a .json file.'''
            # TODO(TG): Automatically detect and convert any CAI dump files
            # loaded in to the Pygmalion format.

            # https://stackoverflow.com/questions/5389507/iterating-over-every-two-elements-in-a-list
            def pairwise(iterable):
                # "s -> (s0, s1), (s2, s3), (s4, s5), ..."
                a = iter(iterable)
                return zip(a, a)

            char_name = char_setting_states[0]
            user_name = char_setting_states[1]

            file_data = json.loads(file_obj.decode('utf-8'))
            model_history = file_data["chat"]

            # Construct a new Gradio history.
            new_gradio_history = []
            for human_turn, bot_turn in pairwise(model_history):
                # Handle the situation where the conversation history is
                # loaded before the character definitions are filled in.
                if char_name == "":
                    # Grab the character's name from the model history.
                    char_name = bot_turn.split(":")[0]

                # Format the user and bot utterances.
                user_turn = human_turn.replace("You: ", "")
                bot_turn = bot_turn.replace(f"{char_name}:",
                                            f"**{char_name}:**")

                # Somebody released a script on /g/ which tries to convert CAI
                # dump logs to Pygmalion character settings and chats. The
                # anonymization of the dumps, however, means that
                # [NAME_IN_MESSAGE_REDACTED] is left in the conversational
                # history, which we obviously don't want. Replacing it here
                # accommodates users of that script, so the placeholder
                # doesn't have to be manually edited out of the conversation
                # JSON. The model shouldn't generate
                # [NAME_IN_MESSAGE_REDACTED] by itself.
                user_turn = user_turn.replace("[NAME_IN_MESSAGE_REDACTED]",
                                              user_name)
                bot_turn = bot_turn.replace("[NAME_IN_MESSAGE_REDACTED]",
                                            user_name)

                new_gradio_history.append((user_turn, bot_turn))

            return model_history, new_gradio_history, new_gradio_history

with gr.Tab("Character Settings") as settings_tab:
charfile, char_setting_states = _build_character_settings_ui()
with gr.Tab("Chat Window"):
chatbot = gr.Chatbot(
label="Your conversation will show up here").style(
color_map=("#326efd", "#212528"))
char_name, _user_name, char_persona, char_greeting, world_scenario, example_dialogue = char_setting_states
charfile.upload(
fn=_char_file_upload,
inputs=[charfile, history_for_model, history_for_gradio],
outputs=[history_for_model, history_for_gradio, chatbot, char_name, char_persona, char_greeting, world_scenario, example_dialogue]
)
            message = gr.Textbox(
                label="Your message (hit Enter to send)",
                placeholder="Write a message...",
            )
            message.submit(
                fn=_run_inference,
                inputs=[
                    history_for_model, history_for_gradio, message,
                    generation_settings, *char_setting_states
                ],
                outputs=[
                    message, history_for_model, history_for_gradio, chatbot
                ],
            )
            with gr.Row():
                send_btn = gr.Button("Send", variant="primary")
                send_btn.click(
                    fn=_run_inference,
                    inputs=[
                        history_for_model, history_for_gradio, message,
                        generation_settings, *char_setting_states
                    ],
                    outputs=[
                        message, history_for_model, history_for_gradio,
                        chatbot
                    ],
                )

                regenerate_btn = gr.Button("Regenerate")
                regenerate_btn.click(
                    fn=_regenerate,
                    inputs=[
                        history_for_model, history_for_gradio,
                        generation_settings, *char_setting_states
                    ],
                    outputs=[
                        message, history_for_model, history_for_gradio,
                        chatbot
                    ],
                )

                undo_btn = gr.Button("Undo last exchange")
                undo_btn.click(
                    fn=_undo_last_exchange,
                    inputs=[history_for_model, history_for_gradio],
                    outputs=[history_for_model, history_for_gradio, chatbot],
                )
            with gr.Row():
                with gr.Column():
                    chatfile = gr.File(
                        type="binary", file_types=[".json"], interactive=True)
                    chatfile.upload(
                        fn=_load_chat_history,
                        inputs=[chatfile, *char_setting_states],
                        outputs=[
                            history_for_model, history_for_gradio, chatbot
                        ],
                    )
                    save_chat_btn = gr.Button(
                        value="Save Conversation History")
                    save_chat_btn.click(
                        _save_chat_history,
                        inputs=[history_for_model, *char_setting_states],
                        outputs=[chatfile],
                    )
                with gr.Column():
                    gr.Markdown("""
                    ### To save a chat
                    Click "Save Conversation History". The file will appear above the button, and you can click it to download.

                    ### To load a chat
                    Drag a valid .json file onto the upload box, or click the box to browse.

                    **Remember to fill out or load your character definitions before resuming a chat!**
                    """)
with gr.Tab("Generation Settings"):
_build_generation_settings_ui(
state=generation_settings,
fn=_update_generation_settings,
for_kobold=for_kobold,
)
return interface
def _char_file_upload(file_obj, history_model, history_gradio):
    file_data = json.loads(file_obj.decode('utf-8'))
    char_name = file_data["char_name"]
    greeting = file_data["char_greeting"]

    # If the chat history is empty so far and the character has a greeting,
    # open the chat with that greeting.
    empty_history = not history_model or (
        len(history_model) <= 2 and history_model[0] == '')
    if empty_history and char_name and greeting:
        s = f'{char_name}: {greeting}'
        t = f'**{char_name}**: {greeting}'
        history_model = ['', s]
        history_gradio = [('', t)]

    return (history_model, history_gradio, history_gradio, char_name,
            file_data["char_persona"], greeting, file_data["world_scenario"],
            file_data["example_dialogue"])

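# For reference, a character file is a flat JSON object with the five fields
# produced by `char_file_create` below (illustrative values):
#   {"char_name": "...", "char_persona": "...", "char_greeting": "...",
#    "world_scenario": "...", "example_dialogue": "..."}
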
def _build_character_settings_ui():
    def char_file_create(char_name, char_persona, char_greeting,
                         world_scenario, example_dialogue):
        with open(char_name + ".json", "w") as f:
            f.write(json.dumps({
                "char_name": char_name,
                "char_persona": char_persona,
                "char_greeting": char_greeting,
                "world_scenario": world_scenario,
                "example_dialogue": example_dialogue,
            }))
        return char_name + ".json"

    with gr.Column():
        with gr.Row():
            char_name = gr.Textbox(
                label="Character Name",
                placeholder="The character's name",
            )
            user_name = gr.Textbox(
                label="Your Name",
                placeholder="What the character should call you",
            )
        char_persona = gr.Textbox(
            label="Character Persona",
            placeholder="Describe the character's persona here. Think of this as CharacterAI's description + definitions in one box.",
            lines=4,
        )
        char_greeting = gr.Textbox(
            label="Character Greeting",
            placeholder="Write the character's greeting here. They will say this verbatim as their first response.",
            lines=3,
        )
        world_scenario = gr.Textbox(
            label="Scenario",
            placeholder="Optionally, describe the starting scenario in a few short sentences.",
        )
        example_dialogue = gr.Textbox(
            label="Example Chat",
            placeholder="Optionally, write in an example chat here. This is useful for showing how the character should talk and behave.",
            lines=4,
        )
        with gr.Row():
            with gr.Column():
                charfile = gr.File(type="binary", file_types=[".json"])
                save_char_btn = gr.Button(value="Generate Character File")
                save_char_btn.click(
                    char_file_create,
                    inputs=[char_name, char_persona, char_greeting,
                            world_scenario, example_dialogue],
                    outputs=[charfile],
                )
            with gr.Column():
                gr.Markdown("""
                ### To save a character
                Click "Generate Character File". The file will appear above the button, and you can click it to download.

                ### To upload a character
                Drag a valid .json file onto the upload box, or click the box to browse.
                """)

    return charfile, (char_name, user_name, char_persona, char_greeting,
                      world_scenario, example_dialogue)

def _build_generation_settings_ui(state, fn, for_kobold):
    generation_defaults = get_generation_defaults(for_kobold=for_kobold)

    with gr.Row():
        with gr.Column():
            max_new_tokens = gr.Slider(
                16,
                512,
                value=generation_defaults["max_new_tokens"],
                step=4,
                label="max_new_tokens",
            )
            max_new_tokens.change(
                lambda state, value: fn(state, "max_new_tokens", value),
                inputs=[state, max_new_tokens],
                outputs=state,
            )

            temperature = gr.Slider(
                0.1,
                2,
                value=generation_defaults["temperature"],
                step=0.01,
                label="temperature",
            )
            temperature.change(
                lambda state, value: fn(state, "temperature", value),
                inputs=[state, temperature],
                outputs=state,
            )

            top_p = gr.Slider(
                0.0,
                1.0,
                value=generation_defaults["top_p"],
                step=0.01,
                label="top_p",
            )
            top_p.change(
                lambda state, value: fn(state, "top_p", value),
                inputs=[state, top_p],
                outputs=state,
            )
        with gr.Column():
            typical_p = gr.Slider(
                0.0,
                1.0,
                value=generation_defaults["typical_p"],
                step=0.01,
                label="typical_p",
            )
            typical_p.change(
                lambda state, value: fn(state, "typical_p", value),
                inputs=[state, typical_p],
                outputs=state,
            )

            repetition_penalty = gr.Slider(
                1.0,
                3.0,
                value=generation_defaults["repetition_penalty"],
                step=0.01,
                label="repetition_penalty",
            )
            repetition_penalty.change(
                lambda state, value: fn(state, "repetition_penalty", value),
                inputs=[state, repetition_penalty],
                outputs=state,
            )

            top_k = gr.Slider(
                0,
                100,
                value=generation_defaults["top_k"],
                step=1,
                label="top_k",
            )
            top_k.change(
                lambda state, value: fn(state, "top_k", value),
                inputs=[state, top_k],
                outputs=state,
            )
            if not for_kobold:
                penalty_alpha = gr.Slider(
                    0,
                    1,
                    value=generation_defaults["penalty_alpha"],
                    step=0.05,
                    label="penalty_alpha",
                )
                penalty_alpha.change(
                    lambda state, value: fn(state, "penalty_alpha", value),
                    inputs=[state, penalty_alpha],
                    outputs=state,
                )
    #
    # Some of these explanations are taken from Kobold:
    # https://github.com/KoboldAI/KoboldAI-Client/blob/main/gensettings.py
    #
    # The settings are passed directly into the `generate` call, so they
    # should exist here:
    # https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig
    #
    with gr.Accordion(label="Helpful information", open=False):
        gr.Markdown("""
        Here's a basic rundown of each setting:

        - `max_new_tokens`: Number of tokens the AI should generate. Higher numbers take longer to generate.
        - `temperature`: Randomness of sampling. Higher values can increase creativity, but may make the text less sensible. Lower values make the text more predictable, but more repetitive.
        - `top_p`: Used to discard unlikely text during sampling. Lower values make the text more predictable, but more repetitive. (Set this to 1 to disable its effect.)
        - `top_k`: Alternative sampling method; can be combined with `top_p`. Keeps only the `top_k` highest-probability vocabulary tokens. (Set this to 0 to disable its effect.)
        - `typical_p`: Alternative sampling method described in the paper "Typical Decoding for Natural Language Generation" (10.48550/ARXIV.2202.00666). The paper suggests 0.2 as a good value for this setting. (Set this to 1 to disable its effect.)
        - `repetition_penalty`: Used to penalize words that were already generated or belong to the context. (Going over 1.2 breaks 6B models; set this to 1.0 to disable its effect.)
        - `penalty_alpha`: The alpha coefficient when using contrastive search.

        Some settings might not show up depending on which inference backend is being used.
        """)