import itertools
from time import perf_counter

import numpy as np
import streamlit as st
from openai import OpenAI
from PIL import Image

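# Page setup: browser tab title, favicon, and a wide, collapsed-sidebar layout.
# The Unify spiral is reused below as the avatar for assistant messages.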
st.set_page_config(
    page_title="Unify Router Demo",
    page_icon="./assets/unify_spiral.png",
    layout="wide",
    initial_sidebar_state="collapsed",
)
router_avatar = np.array(Image.open("./assets/unify_spiral.png"))

with open("./style.css") as css:
    st.markdown(f"<style>{css.read()}</style>", unsafe_allow_html=True)

st.info(
    body="This demo is only a preview. Check out our [Chat UI](https://unify.ai/chat) "
    "for the full experience, including more endpoints and extra customization!",
    icon="ℹ️",
)

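# Display labels mapped to the routing-strategy and model ids understood by the
# Unify API; a strategy and a model are later combined into a single endpoint string.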
strategies = {
    "🚀 fastest": "tks-per-sec",
    "⚡ most responsive": "ttft",
    "💵 cheapest": "input-cost",
}
models = {
    "🦙 Llama2 70B Chat": "llama-2-70b-chat",
    "💨 Mixtral 8x7B Instruct": "mixtral-8x7b-instruct-v0.1",
    "💎 Gemma 7B": "gemma-7b-it",
}

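# Two-column layout: routing parameters on the left, the chat itself on the right.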
Parameters_Col, Chat_Col = st.columns([1, 3])

with Parameters_Col:

    st.image(
        "./assets/unify_logo.png",
        use_column_width="auto",
    )
    st.markdown(
        "Send your prompts to the best LLM endpoint and optimize performance, "
        "all with a **single API**"
    )

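    # Routing controls, phrased so they read as "I want the <metric> endpoint for <model>".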
    strategy = st.selectbox(
        label="I want the",
        options=tuple(strategies.keys()),
        help="Choose the metric to optimize the routing for. "
        "Fastest picks the endpoint with the highest number of output tokens per second. "
        "Most responsive picks the endpoint with the shortest time to first token. "
        "Cheapest picks the endpoint with the lowest input token cost.",
    )
    model = st.selectbox(
        label="endpoint for",
        options=tuple(models.keys()),
        help="Select a model to optimize for. The same model can be offered by "
        "different endpoint providers. The router lets you find the optimal endpoint "
        "for your chosen model, target metric, and input prompt.",
    )

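    # Optional generation parameters forwarded to the chat completion request.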
    with st.expander("Advanced Inputs"):
        max_tokens = st.slider(
            label="Maximum Number Of Tokens",
            min_value=100,
            max_value=2000,
            value=500,
            step=100,
            help="The maximum number of tokens that can be generated.",
        )
        temperature = st.slider(
            label="Temperature",
            min_value=0.0,
            max_value=1.0,
            value=0.5,
            step=0.5,
            help="The model's output randomness. Higher values give more random outputs.",
        )

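# Chat column: persistent message history, streamed responses, and routing stats.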
with Chat_Col:

    if "messages" not in st.session_state:
        st.session_state.messages = []
    msgs = st.container(height=350)

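    # Replay the stored conversation; assistant turns get the router avatar.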
    for msg in st.session_state.messages:
        if msg["role"] == "user":
            msgs.chat_message(msg["role"]).write(msg["content"])
        else:
            msgs.chat_message(msg["role"], avatar=router_avatar).write(msg["content"])

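    # The Unify router speaks the OpenAI API, so the stock OpenAI client is pointed
    # at the Unify base URL and authenticated with a Unify key from Streamlit secrets.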
    client = OpenAI(
        base_url="https://api.unify.ai/v0/",
        api_key=st.secrets["UNIFY_API"],
    )

    if prompt := st.chat_input("Enter your prompt..."):

        st.session_state.messages.append({"role": "user", "content": prompt})
        with msgs.chat_message("user"):
            st.write(prompt)

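        # Route the prompt and stream the reply inside a status box, timing the
        # request so throughput stats can be reported afterwards.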
        with msgs.status("Routing your prompt...", expanded=True):

            start = perf_counter()
            stream = client.chat.completions.create(
                # "<model>@<strategy>", e.g. "llama-2-70b-chat@ttft", asks the router
                # to serve this model from the endpoint that best matches the metric.
                model="@".join([models[model], strategies[strategy]]),
                messages=[
                    {"role": m["role"], "content": m["content"]}
                    for m in st.session_state.messages
                ],
                stream=True,
                max_tokens=max_tokens,
                temperature=temperature,
            )

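            # Duplicate the stream: one copy is rendered live, the other is buffered
            # so the usage stats in the final chunk can be read afterwards.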
            stream, stream_copy = itertools.tee(stream)
            st.write_stream(stream)
            chunks = [chunk for chunk in stream_copy]
            # Measured only after the stream has been fully consumed, so the value
            # reflects the complete generation rather than just the initial response.
            time_to_completion = round(perf_counter() - start, 2)

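            # The final chunk carries the usage block with token counts and cost.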
            last_chunk = chunks[-1]
            cost = round(last_chunk.usage["cost"], 6)
            output_tokens = last_chunk.usage["completion_tokens"]
            tokens_per_second = round(output_tokens / time_to_completion, 2)

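            # The serving endpoint is reported back as "<model>@<provider>"; turn the
            # provider id into a readable display name.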
            provider = " ".join(chunks[0].model.split("@")[-1].split("-")).title()
            if " Ai" in provider:
                provider = provider.replace(" Ai", " AI")
            st.markdown(f"Model: **{model}**. Provider: **{provider}**")
            st.markdown(
                f"**{tokens_per_second}** Tokens Per Second - "
                f"**{time_to_completion}** Seconds to complete - "
                f"**{cost:.6f}** $"
            )

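        # Rebuild the full assistant reply from the buffered chunks and store it so
        # it is replayed on the next rerun.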
        output_chunks = [chunk.choices[0].delta.content or "" for chunk in chunks]
        response = "".join(output_chunks)
        st.session_state.messages.append({"role": "assistant", "content": response})

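    # Wipe both the rendered chat container and the stored history.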
    if st.button("Clear Chat", key="clear"):
        msgs.empty()
        st.session_state.messages = []