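# Gradio app for chatting with the Llama 3.1 405B base model, served through
# Hyperbolic's OpenAI-compatible completions API. Supports ChatML or a plain
# "<role>:" prompt format.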
import os

import gradio as gr
from openai import OpenAI
import jinja2
from transformers import AutoTokenizer

# Initialize the OpenAI client
client = OpenAI(
    base_url="https://api.hyperbolic.xyz/v1",
    api_key=os.environ["HYPERBOLIC_API_KEY"],
)

# The tokenizer complains later (after gradio forks) without this setting.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Use an unofficial copy of Llama to avoid access restrictions.
tokenizer = AutoTokenizer.from_pretrained("mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated")
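# (This tokenizer is only used locally to count tokens when trimming history;
# generation itself happens on Hyperbolic's servers.)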

# Initial prompt presets (not currently wired into the UI below).
initial_prompts = {
    "Default": ["405B", """A chat between a person and the Llama 3.1 405B base model.
"""],
}

# ChatML template
chatml_template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"""

# Plain-text template: one "<role>: message" line per turn.
chat_template = """{% for message in messages %}{{'<' + message['role'] + '>: ' + message['content'] + '\n'}}{% endfor %}"""

def format_chat(messages, use_chatml=False):
    if use_chatml:
        template = jinja2.Template(chatml_template)
    else:
        template = jinja2.Template(chat_template)
    formatted = template.render(messages=messages)
    return formatted


def count_tokens(text):
    return len(tokenizer.encode(text))
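
# limit_history keeps only the most recent (user, assistant) pairs that fit in
# max_tokens alongside the initial prompt and the new message, walking the
# history from newest to oldest.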
def limit_history(initial_prompt, history, new_message, max_tokens):
    limited_history = []
    token_count = count_tokens(new_message) + count_tokens(initial_prompt)
    if token_count > max_tokens:
        raise ValueError("message too large for context window")
    for user_msg, assistant_msg in reversed(history):
        # TODO add ChatML wrapping here for better counting?
        user_tokens = count_tokens(user_msg)
        assistant_tokens = count_tokens(assistant_msg)
        if token_count + user_tokens + assistant_tokens > max_tokens:
            break
        token_count += user_tokens + assistant_tokens
        limited_history.insert(0, (user_msg, assistant_msg))
    return limited_history
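
# generate_response assembles the full prompt (initial prompt + trimmed history
# + new message) in either ChatML or plain "<role>:" format, then requests a
# raw completion from the 405B base model.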
def generate_response(message, history, initial_prompt, user_role, assistant_role, use_chatml):
    context_length = 8192
    response_length = 1000
    slop_length = 300  # slop for ChatML encoding, etc. (TODO: fix this)

    # Trim history based on token count.
    history_tokens = context_length - response_length - slop_length
    limited_history = limit_history(initial_prompt, history, message, max_tokens=history_tokens)

    # Prepare the input: alternate user/assistant roles over the flattened history plus the new message.
    chat_history = [{"role": user_role if i % 2 == 0 else assistant_role, "content": m}
                    for i, m in enumerate([item for sublist in limited_history for item in sublist] + [message])]
    formatted_input = format_chat(chat_history, use_chatml)

    if use_chatml:
        full_prompt = "<|im_start|>system\n" + initial_prompt + "<|im_end|>\n" + formatted_input + f"<|im_start|>{assistant_role}\n"
    else:
        full_prompt = initial_prompt + "\n\n" + formatted_input + f"<{assistant_role}>:"

    completion = client.completions.create(
        model="meta-llama/Meta-Llama-3.1-405B",
        prompt=full_prompt,
        temperature=0.7,
        frequency_penalty=0.1,
        max_tokens=response_length,
        stop=[f'<{user_role}>:', f'<{assistant_role}>:'] if not use_chatml else ['<|im_end|>'],
    )
    assistant_response = completion.choices[0].text.strip()
    return assistant_response
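
# Gradio UI: a prompt/role configuration row plus a ChatInterface wired to generate_response.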
with gr.Blocks(theme=gr.themes.Soft()) as iface:
    with gr.Row():
        initial_prompt = gr.Textbox(
            value="Please respond in whatever manner comes most naturally to you. You do not need to act as an assistant.",
            label="Initial Prompt",
            lines=3
        )
        with gr.Column():
            user_role = gr.Textbox(value="user", label="User Role")
            assistant_role = gr.Textbox(value="model", label="Assistant Role")
            use_chatml = gr.Checkbox(label="Use ChatML", value=True)
    chatbot = gr.ChatInterface(
        generate_response,
        title="Chat with 405B",
        additional_inputs=[initial_prompt, user_role, assistant_role, use_chatml],
        concurrency_limit=10,
        chatbot=gr.Chatbot(height=600)
    )
    gr.Markdown("""
    This chat interface is powered by the Llama 3.1 405B base model, served by [Hyperbolic](https://hyperbolic.xyz), The Open Access AI Cloud.
    Thank you to Hyperbolic for making this base model available!
    """)

# Launch the interface
iface.launch(share=True, max_threads=40)