# Matt · Use partial for everything · 8e3579a · 6.71 kB
# (file-viewer residue: raw | history | blame)
import gradio as gr
from transformers import AutoTokenizer
import json
from functools import partial
# Hub checkpoint whose tokenizer is used to render every template. Its
# chat_template attribute is overwritten on each render by apply_chat_template;
# presumably the checkpoint still provides the bos_token / eos_token values the
# templates reference — confirm against the transformers docs if that matters.
_BASE_CHECKPOINT = "HuggingFaceH4/zephyr-7b-beta"
tokenizer = AutoTokenizer.from_pretrained(_BASE_CHECKPOINT)
# Default contents of the "Conversation" box. The app parses this text with
# json.loads before rendering, so it has to remain a JSON array of
# {"role": ..., "content": ...} objects.
demo_conversation = (
    "[\n"
    '{"role": "system", "content": "You are a helpful chatbot."},\n'
    '{"role": "user", "content": "Hi there!"},\n'
    '{"role": "assistant", "content": "Hello, human!"},\n'
    '{"role": "user", "content": "Can I ask a question?"}\n'
    "]"
)
# Built-in starting points for the template editor, keyed by the name that the
# preset buttons pass to load_template. Each value is Jinja source.
#
# Newlines that belong in the *rendered* output are written as \\n in this
# Python source, so the template text itself contains the two-character escape
# `\n`, which Jinja decodes inside its string literals. A literal newline inside
# a Jinja string would be deleted by the "Cleanup template whitespace" option.
# That option strips every line of the template and joins the lines with no
# separator. Deleting those newlines would silently drop them from the output.
chat_templates = {
    "chatml": """{% for message in messages %}
{{ "<|im_start|>" + message["role"] + "\\n" + message["content"] + "<|im_end|>\\n" }}
{% endfor %}
{% if add_generation_prompt %}
{{ "<|im_start|>assistant\\n" }}
{% endif %}""",
    "zephyr": """{% for message in messages %}
{% if message['role'] == 'user' %}
{{ '<|user|>\\n' + message['content'] + eos_token }}
{% elif message['role'] == 'system' %}
{{ '<|system|>\\n' + message['content'] + eos_token }}
{% elif message['role'] == 'assistant' %}
{{ '<|assistant|>\\n' + message['content'] + eos_token }}
{% endif %}
{% if loop.last and add_generation_prompt %}
{{ '<|assistant|>' }}
{% endif %}
{% endfor %}""",
    "llama": """{% if messages[0]['role'] == 'system' %}
{% set loop_messages = messages[1:] %}
{% set system_message = messages[0]['content'] %}
{% else %}
{% set loop_messages = messages %}
{% set system_message = false %}
{% endif %}
{% for message in loop_messages %}
{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
{% endif %}
{% if loop.index0 == 0 and system_message != false %}
{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}
{% else %}
{% set content = message['content'] %}
{% endif %}
{% if message['role'] == 'user' %}
{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}
{% elif message['role'] == 'assistant' %}
{{ ' ' + content.strip() + ' ' + eos_token }}
{% endif %}
{% endfor %}""",
    "alpaca": """{% for message in messages %}
{% if message['role'] == 'system' %}
{{ message['content'] + '\\n\\n' }}
{% elif message['role'] == 'user' %}
{{ '### Instruction:\\n' + message['content'] + '\\n\\n' }}
{% elif message['role'] == 'assistant' %}
{{ '### Response:\\n' + message['content'] + '\\n\\n' }}
{% endif %}
{% if loop.last and add_generation_prompt %}
{{ '### Response:\\n' }}
{% endif %}
{% endfor %}""",
    "vicuna": """{% for message in messages %}
{% if message['role'] == 'system' %}
{{ message['content'] + '\\n' }}
{% elif message['role'] == 'user' %}
{{ 'USER:\\n' + message['content'] + '\\n' }}
{% elif message['role'] == 'assistant' %}
{{ 'ASSISTANT:\\n' + message['content'] + '\\n' }}
{% endif %}
{% if loop.last and add_generation_prompt %}
{{ 'ASSISTANT:\\n' }}
{% endif %}
{% endfor %}""",
    "falcon": """{% for message in messages %}
{% if not loop.first %}
{{ '\\n' }}
{% endif %}
{% if message['role'] == 'system' %}
{{ 'System: ' }}
{% elif message['role'] == 'user' %}
{{ 'User: ' }}
{% elif message['role'] == 'assistant' %}
{{ 'Falcon: ' }}
{% endif %}
{{ message['content'] }}
{% endfor %}
{% if add_generation_prompt %}
{{ '\\n' + 'Falcon:' }}
{% endif %}""",
}
# Markdown rendered at the top of the page by gr.Markdown.
description_text = "\n".join(
    [
        "# Chat Template Creator",
        "### This space is a helper app for writing [Chat Templates](https://huggingface.co/docs/transformers/main/en/chat_templating).",
        "### When you're happy with the outputs from your template, you can use the code block at the end to add it to a PR!",
    ]
)
def apply_chat_template(template, test_conversation, add_generation_prompt, cleanup_whitespace):
    """Render a test conversation with a user-written chat template.

    Args:
        template: Jinja chat-template source, as typed in the editor.
        test_conversation: JSON text holding the conversation to render
            (a list of {"role": ..., "content": ...} messages).
        add_generation_prompt: Forwarded unchanged to
            ``tokenizer.apply_chat_template``.
        cleanup_whitespace: If true, every line of ``template`` is stripped and
            the lines are joined with no separator, so editor layout
            (indentation, line breaks) does not leak into the output.

    Returns:
        ``(formatted, pr_snippet)``: the rendered conversation string, and a
        Python snippet that sets this template on a tokenizer and pushes it to
        the Hub as a PR.

    Raises:
        gr.Error: If ``test_conversation`` is not valid JSON.
    """
    if cleanup_whitespace:
        # Real newlines are removed here; newlines meant for the output must be
        # written as the escape \n inside Jinja string literals.
        template = "".join(line.strip() for line in template.split("\n"))

    try:
        conversation = json.loads(test_conversation)
    except json.JSONDecodeError as err:
        # gr.Error shows this message to the user instead of a generic failure.
        raise gr.Error(f"The conversation is not valid JSON: {err}") from err

    tokenizer.chat_template = template
    formatted = tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=add_generation_prompt
    )

    # {template!r} emits a valid Python string literal even when the template
    # contains quotes, backslashes or newlines. Wrapping it in hand-written
    # double quotes produced broken code, e.g. for the ChatML template.
    pr_snippet = "\n".join(
        (
            "from transformers import AutoTokenizer",
            'CHECKPOINT = "big-ai-company/cool-new-model"',
            "tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)",
            f"tokenizer.chat_template = {template!r}",
            "tokenizer.push_to_hub(CHECKPOINT, create_pr=True)",
        )
    )
    return formatted, pr_snippet
def load_template(template_name):
    """Return the source of the built-in template ``template_name``.

    Used as a preset-button callback. Gradio updates a component from the value
    a callback *returns* (via the listener's ``outputs``). Assigning
    ``template_in.value`` inside a callback does not change the page the user
    already has open.

    Raises:
        KeyError: If ``template_name`` is not a key of ``chat_templates``.
    """
    return chat_templates[template_name]


with gr.Blocks() as demo:
    gr.Markdown(description_text)
    with gr.Row():
        gr.Markdown("### Pick an existing template to start:")
    with gr.Row():
        load_chatml = gr.Button("ChatML")
        load_zephyr = gr.Button("Zephyr")
        load_llama = gr.Button("LLaMA")
    with gr.Row():
        load_alpaca = gr.Button("Alpaca")
        load_vicuna = gr.Button("Vicuna")
        load_falcon = gr.Button("Falcon")
    with gr.Row():
        with gr.Column():
            template_in = gr.TextArea(value=chat_templates["chatml"], lines=10, max_lines=30, label="Chat Template")
            conversation_in = gr.TextArea(value=demo_conversation, lines=6, label="Conversation")
            generation_prompt_check = gr.Checkbox(value=False, label="Add generation prompt")
            cleanup_whitespace_check = gr.Checkbox(value=True, label="Cleanup template whitespace")
            submit = gr.Button("Apply template", variant="primary")
        with gr.Column():
            formatted_out = gr.TextArea(label="Formatted conversation")
            code_snippet_out = gr.TextArea(label="Code snippet to create PR", lines=3, show_label=True, show_copy_button=True)

    submit.click(
        fn=apply_chat_template,
        inputs=[template_in, conversation_in, generation_prompt_check, cleanup_whitespace_check],
        outputs=[formatted_out, code_snippet_out],
    )

    # Each preset button writes its template into the editor. partial() binds
    # the name eagerly, so the loop variable is not late-bound.
    for preset_button, preset_name in (
        (load_chatml, "chatml"),
        (load_zephyr, "zephyr"),
        (load_llama, "llama"),
        (load_alpaca, "alpaca"),
        (load_vicuna, "vicuna"),
        (load_falcon, "falcon"),
    ):
        preset_button.click(fn=partial(load_template, preset_name), outputs=template_in)

demo.launch()
#iface = gr.Interface(
# description=description_text,
# fn=apply_chat_template,
# inputs=[
# gr.TextArea(value=default_template, lines=10, max_lines=30, label="Chat Template"),
# gr.TextArea(value=demo_conversation, lines=6, label="Conversation"),
# gr.Checkbox(value=False, label="Add generation prompt"),
# gr.Checkbox(value=True, label="Cleanup template whitespace"),
# ],
# outputs=[
# gr.TextArea(label="Formatted conversation"),
# gr.TextArea(label="Code snippet to create PR", lines=3, show_label=True, show_copy_button=True)
# ]
#)
#iface.launch()