# Source: Hugging Face file view, uploaded by marcofrodl.
# Commit 912cf21 (verified): "added new Mistral models" -- file size 4 kB.
import gradio as gr
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage
def get_stream_chat_completion(
    message, chat_history, model, api_key, system=None, **kwargs
):
    """Stream the assistant's reply to ``message`` as text fragments.

    The conversation sent to the API is assembled in this order: the
    optional system prompt, every ``(user, assistant)`` pair found in
    ``chat_history``, and finally ``message`` as the newest user turn.

    Args:
        message: Text of the new user turn.
        chat_history: Iterable of two-item entries, each unpacked as
            ``(human_message, bot_message)``.
        model: Model identifier forwarded to ``chat_stream``.
        api_key: Key used to construct the ``MistralClient``.
        system: Optional system prompt; omitted when ``None``.
        **kwargs: Forwarded unchanged to ``MistralClient.chat_stream``.

    Yields:
        The ``delta.content`` of the first choice of each streamed chunk,
        skipping chunks whose content is ``None``.
    """
    conversation = []
    if system is not None:
        conversation.append(ChatMessage(role="system", content=system))

    # Replay earlier turns so the model sees the whole dialogue.
    for user_text, assistant_text in chat_history:
        conversation.append(ChatMessage(role="user", content=user_text))
        conversation.append(
            ChatMessage(role="assistant", content=assistant_text)
        )

    conversation.append(ChatMessage(role="user", content=message))

    stream = MistralClient(api_key=api_key).chat_stream(
        model=model,
        messages=conversation,
        **kwargs,
    )
    for event in stream:
        fragment = event.choices[0].delta.content
        if fragment is None:
            # Chunk carries no text; nothing to forward.
            continue
        yield fragment
def respond_stream(
    message,
    chat_history,
    api_key,
    model,
    temperature,
    top_p,
    max_tokens,
    system,
):
    """Yield the progressively growing reply for the chat UI.

    Each fragment produced by ``get_stream_chat_completion`` is appended
    to the text received so far, and the accumulated text is yielded
    after every fragment.

    Args:
        message: Text of the new user turn.
        chat_history: Previous turns, passed through unchanged.
        api_key: Mistral API key.
        model: Model identifier.
        temperature: Sampling temperature, passed through unchanged.
        top_p: Nucleus-sampling value, passed through unchanged.
        max_tokens: Token limit; converted with ``int()`` before use.
        system: System prompt; a falsy value (e.g. ``""``) is sent as
            ``None``.

    Yields:
        The accumulated reply text after each fragment. If the stream
        produced no fragment at all, a Gradio warning is raised and a
        single empty string is yielded instead.
    """
    fragments = get_stream_chat_completion(
        message=message,
        chat_history=chat_history,
        model=model,
        api_key=api_key,
        temperature=temperature,
        top_p=top_p,
        max_tokens=int(max_tokens),
        system=system or None,
    )

    reply_so_far = ""
    got_fragment = False
    for fragment in fragments:
        reply_so_far += fragment
        yield reply_so_far
        got_fragment = True

    if got_fragment:
        return

    # An entirely empty stream is reported to the user as a key problem.
    gr.Warning("Error: Invalid API Key")
    yield ""
# Custom CSS handed to gr.Blocks below; "header-text" is the class applied to
# the title Markdown component.
css = """
.header-text p {line-height: 80px !important; text-align: left; font-size: 26px;}
.header-logo {text-align: left;}
.image-container img {max-width: 80px; height: auto;}
"""

with gr.Blocks(title="Mistral Playground", css=css) as mistral_playground:
    # Header: logo in a narrow column, title in a wide one.
    with gr.Row():
        with gr.Column(scale=1, min_width=80):
            gr.Image(
                "tt-logo.jpg",
                show_download_button=False,
                show_share_button=False,
                interactive=False,
                show_label=False,
                elem_id="thinktecture-logo",
                container=False,
            )
        with gr.Column(scale=11):
            gr.Markdown(
                "Thinktecture Mistral AI Playground",
                elem_classes="header-text",
            )

    # Credentials and model selection.
    with gr.Row(variant='panel'):
        with gr.Column(scale=5):
            api_key = gr.Textbox(
                type='password',
                placeholder='Your Mistral AI API key',
                lines=1,
                label="Mistral AI API Key",
            )
        with gr.Column(scale=7):
            # Each choice is a [display label, model id] pair; the default
            # value is a model id.
            model = gr.Radio(
                label="Mistral AI Model",
                choices=[
                    ["7B", "open-mistral-7b"],
                    ["8x7B", "open-mixtral-8x7b"],
                    ["Small", "mistral-small-latest"],
                    ["Medium", "mistral-medium-latest"],
                    ["8x22B", "open-mixtral-8x22b"],
                    ["Large", "mistral-large-latest"],
                    ["Codestral", "codestral-latest"],
                ],
                value="mistral-large-latest",
            )

    # Sampling parameters.
    with gr.Row(variant='panel'):
        temperature = gr.Slider(
            label="Temperature", minimum=0.0, maximum=1.0, step=0.1, value=0.1
        )
        top_p = gr.Slider(
            label="Top P", minimum=0.0, maximum=1.0, step=0.01, value=0.95
        )
        max_tokens = gr.Slider(
            label="Max Tokens",
            minimum=1000,
            maximum=32000,
            step=1000,
            value=8000,
        )

    with gr.Row(variant='panel'):
        system = gr.Textbox(
            lines=2,
            label="System Message",
            value="You are a helpful AI assistant",
        )

    # The additional inputs are listed in the same order as the parameters
    # that follow (message, chat_history) in respond_stream's signature.
    gr.ChatInterface(
        respond_stream,
        chatbot=gr.Chatbot(render=False, height=500, layout="panel"),
        additional_inputs=[
            api_key,
            model,
            temperature,
            top_p,
            max_tokens,
            system,
        ],
    )

    # Footer. The links use target='_blank' so they open in a new tab; the
    # previous markup used the invalid attribute _target='blank', which
    # browsers ignore.
    with gr.Row():
        gr.HTML(
            value=(
                "<p style='margin-top: 1rem; margin-bottom: 1rem; text-align: center;'>"
                "Developed by Marco Frodl, Principal Consultant for Generative AI @ "
                "<a href='https://go.mfr.one/tt-en' target='_blank'>Thinktecture AG</a>"
                " -- Released 06/09/2024 -- More about me on my "
                "<a href='https://go.mfr.one/marcofrodl-en' target='_blank'>profile page</a>"
                "</p>"
            )
        )

mistral_playground.launch()