405chat / app.py
import os
import gradio as gr
from openai import OpenAI
import jinja2
from transformers import AutoTokenizer
# Initialize the OpenAI-compatible client pointed at Hyperbolic's API.
client = OpenAI(
    base_url="https://api.hyperbolic.xyz/v1",
    api_key=os.environ["HYPERBOLIC_API_KEY"],
)
# Disable tokenizer parallelism; otherwise the tokenizer emits warnings after Gradio forks worker processes.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Use an unofficial mirror of the Llama 3.1 tokenizer to avoid the gated-repo access restrictions.
tokenizer = AutoTokenizer.from_pretrained("mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated")
# Initial prompt presets: name -> [short title, prompt text].
initial_prompts = {
    "Default": ["405B", """A chat between a person and the Llama 3.1 405B base model.
"""],
}
# ChatML template
chatml_template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"""
# Plain-text fallback template: "<role>: content" lines.
chat_template = """{% for message in messages %}{{'<' + message['role'] + '>: ' + message['content'] + '\n'}}{% endfor %}"""
def format_chat(messages, use_chatml=False):
    if use_chatml:
        template = jinja2.Template(chatml_template)
    else:
        template = jinja2.Template(chat_template)
    formatted = template.render(messages=messages)
    return formatted
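# For illustration only (a hypothetical two-turn history, not part of the app), format_chat
# renders roughly the following with user_role="user" and assistant_role="model":
#
#   ChatML (use_chatml=True):
#     <|im_start|>user
#     Hi there<|im_end|>
#     <|im_start|>model
#     Hello!<|im_end|>
#
#   Plain template (use_chatml=False):
#     <user>: Hi there
#     <model>: Hello!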
def count_tokens(text):
    return len(tokenizer.encode(text))
def limit_history(initial_prompt, history, new_message, max_tokens):
    """Keep the most recent (user, assistant) pairs that fit within max_tokens."""
    limited_history = []
    token_count = count_tokens(new_message) + count_tokens(initial_prompt)
    if token_count > max_tokens:
        raise ValueError("message too large for context window")
    # Walk the history newest-first, stopping once the token budget is exhausted.
    for user_msg, assistant_msg in reversed(history):
        # TODO add ChatML wrapping here for better counting?
        user_tokens = count_tokens(user_msg)
        assistant_tokens = count_tokens(assistant_msg)
        if token_count + user_tokens + assistant_tokens > max_tokens:
            break
        token_count += user_tokens + assistant_tokens
        limited_history.insert(0, (user_msg, assistant_msg))
    return limited_history
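# A minimal usage sketch (the history and budget below are made up for illustration):
#
#   example_history = [("Hi", "Hello!"), ("How are you?", "Doing well.")]
#   trimmed = limit_history("Be brief.", example_history, "Tell me a joke.", max_tokens=6892)
#   # -> keeps as many of the most recent (user, assistant) pairs as fit the budget;
#   #    the oldest turns are dropped first.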
def generate_response(message, history, initial_prompt, user_role, assistant_role, use_chatml):
    context_length = 8192
    response_length = 1000
    slop_length = 300  # rough headroom for ChatML/system markup -- TODO count this precisely
    # Trim history so prompt + history + new message + response fit the context window
    # (8192 - 1000 - 300 = 6892 tokens available for the history budget).
    history_tokens = context_length - response_length - slop_length
    limited_history = limit_history(initial_prompt, history, message, max_tokens=history_tokens)
    # Flatten the (user, assistant) pairs, append the new message, and alternate roles.
    chat_history = [{"role": user_role if i % 2 == 0 else assistant_role, "content": m}
                    for i, m in enumerate([item for sublist in limited_history for item in sublist] + [message])]
    formatted_input = format_chat(chat_history, use_chatml)
    if use_chatml:
        full_prompt = "<|im_start|>system\n" + initial_prompt + "<|im_end|>\n" + formatted_input + f"<|im_start|>{assistant_role}\n"
    else:
        full_prompt = initial_prompt + "\n\n" + formatted_input + f"<{assistant_role}>:"
    completion = client.completions.create(
        model="meta-llama/Meta-Llama-3.1-405B",
        prompt=full_prompt,
        temperature=0.7,
        frequency_penalty=0.1,
        max_tokens=response_length,
        stop=[f'<{user_role}>:', f'<{assistant_role}>:'] if not use_chatml else ["<|im_end|>"],
    )
    assistant_response = completion.choices[0].text.strip()
    return assistant_response
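# In ChatML mode the assembled prompt sent to the completions API looks roughly like this
# (placeholders in braces stand for the runtime values):
#
#   <|im_start|>system
#   {initial_prompt}<|im_end|>
#   ...earlier turns...
#   <|im_start|>user
#   {message}<|im_end|>
#   <|im_start|>{assistant_role}
#
# and generation stops at the first <|im_end|>.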
with gr.Blocks(theme=gr.themes.Soft()) as iface:
    with gr.Row():
        initial_prompt = gr.Textbox(
            value="Please respond in whatever manner comes most naturally to you. You do not need to act as an assistant.",
            label="Initial Prompt",
            lines=3
        )
        with gr.Column():
            user_role = gr.Textbox(value="user", label="User Role")
            assistant_role = gr.Textbox(value="model", label="Assistant Role")
            use_chatml = gr.Checkbox(label="Use ChatML", value=True)
    chatbot = gr.ChatInterface(
        generate_response,
        title="Chat with 405B",
        additional_inputs=[initial_prompt, user_role, assistant_role, use_chatml],
        concurrency_limit=10,
        chatbot=gr.Chatbot(height=600)
    )
    gr.Markdown("""
This chat interface is powered by the Llama 3.1 405B base model, served by [Hyperbolic](https://hyperbolic.xyz), The Open Access AI Cloud.
Thank you to Hyperbolic for making this base model available!
""")
# Launch the interface
iface.launch(share=True, max_threads=40)
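# To run locally (assumes a Hyperbolic API key is available in the environment):
#   HYPERBOLIC_API_KEY=<your key> python app.py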