Spaces:
Running
Running
import gradio as gr | |
from mistralai.client import MistralClient | |
from mistralai.models.chat_completion import ChatMessage | |
def get_stream_chat_completion(
    message, chat_history, model, api_key, system=None, **kwargs
):
    """Stream a Mistral chat completion for ``message`` given the prior chat.

    Args:
        message: The new user message to send.
        chat_history: Previous turns as passed by ``gr.ChatInterface``, either
            ``[user, assistant]`` pairs ("tuples" format) or
            ``{"role": ..., "content": ...}`` dicts ("messages" format).
            Turns with a missing or empty side are skipped (see below).
        model: Mistral model identifier, e.g. ``"mistral-large-latest"``.
        api_key: Mistral AI API key used to construct the client.
        system: Optional system prompt placed first in the conversation.
        **kwargs: Extra arguments forwarded to ``MistralClient.chat_stream``
            (the caller in this file passes temperature, top_p, max_tokens).

    Yields:
        Each non-empty text delta of the streamed reply, in order.
    """
    messages = []
    if system is not None:
        messages.append(ChatMessage(role="system", content=system))

    for chat in chat_history or []:
        if isinstance(chat, dict):
            content = chat.get("content")
            # "messages"-format entries may hold non-text payloads; only
            # forward plain, non-empty text.
            if isinstance(content, str) and content:
                messages.append(ChatMessage(role=chat["role"], content=content))
            continue

        human_message, bot_message = chat
        # A failed earlier request can leave a turn like [user, None]. Sending a
        # None content would break this and every later request, and dropping
        # only one side would break user/assistant alternation, so skip the
        # whole turn.
        if not human_message or not bot_message:
            continue
        messages.extend(
            (
                ChatMessage(role="user", content=human_message),
                ChatMessage(role="assistant", content=bot_message),
            )
        )

    messages.append(ChatMessage(role="user", content=message))

    client = MistralClient(api_key=api_key)
    for chunk in client.chat_stream(
        model=model,
        messages=messages,
        **kwargs,
    ):
        # Guard against chunks without choices or without a text delta.
        if not chunk.choices:
            continue
        delta = chunk.choices[0].delta.content
        if delta:
            yield delta
def respond_stream(
    message,
    chat_history,
    api_key,
    model,
    temperature,
    top_p,
    max_tokens,
    system,
):
    """``gr.ChatInterface`` handler that streams the model's reply.

    Args:
        message: The user's new message.
        chat_history: Prior turns supplied by ``gr.ChatInterface``.
        api_key: Mistral AI API key from the password textbox.
        model: Selected Mistral model identifier.
        temperature: Sampling temperature from the slider.
        top_p: Nucleus-sampling parameter from the slider.
        max_tokens: Maximum completion length. Slider values may be floats,
            so the value is converted to ``int``.
        system: System prompt. An empty string means "no system prompt".

    Yields:
        The accumulated reply after each streamed chunk. On failure it yields
        the partial reply received so far (``""`` if nothing arrived) and
        shows a Gradio warning toast instead of crashing the handler.
    """
    # Local import: only this handler needs the client's exception hierarchy.
    from mistralai.exceptions import MistralException

    if not api_key:
        gr.Warning("Error: Please enter your Mistral AI API key")
        yield ""
        return

    response = ""
    try:
        for chunk in get_stream_chat_completion(
            message=message,
            chat_history=chat_history,
            model=model,
            api_key=api_key,
            temperature=temperature,
            top_p=top_p,
            max_tokens=int(max_tokens),
            system=system if system else None,
        ):
            response += chunk
            yield response
    except MistralException as exc:
        # Client-side failures (rejected key, bad request, connection error)
        # are raised while iterating the stream, not reported as "no chunks".
        if getattr(exc, "http_status", None) == 401:
            gr.Warning("Error: Invalid API Key")
        else:
            gr.Warning(f"Error: {exc}")
        yield response
        return

    if not response:
        gr.Warning("Error: The model returned an empty response")
        yield ""
# Styling for the header text and the logo image in the layout below.
css = """
.header-text p {line-height: 80px !important; text-align: left; font-size: 26px;}
.header-logo {text-align: left;}
.image-container img {max-width: 80px; height: auto;}
"""

with gr.Blocks(title="Mistral Playground", css=css) as mistral_playground:
    # Header: logo on the left, title text on the right.
    with gr.Row():
        with gr.Column(scale=1, min_width=80):
            gr.Image(
                "tt-logo.jpg",
                show_download_button=False,
                show_share_button=False,
                interactive=False,
                show_label=False,
                elem_id="thinktecture-logo",
                container=False,
            )
        with gr.Column(scale=11):
            gr.Markdown("Thinktecture Mistral AI Playground", elem_classes="header-text")

    # Credentials and model selection.
    with gr.Row(variant='panel'):
        with gr.Column(scale=5):
            api_key = gr.Textbox(
                type='password',
                placeholder='Your Mistral AI API key',
                lines=1,
                label="Mistral AI API Key",
            )
        with gr.Column(scale=7):
            # Each choice is [label shown in the UI, model id sent to the API].
            model = gr.Radio(
                label="Mistral AI Model",
                choices=[
                    ["7B", "open-mistral-7b"],
                    ["8x7B", "open-mixtral-8x7b"],
                    ["Small", "mistral-small-latest"],
                    ["Medium", "mistral-medium-latest"],
                    ["8x22B", "open-mixtral-8x22b"],
                    ["Large", "mistral-large-latest"],
                    ["Codestral", "codestral-latest"],
                ],
                value="mistral-large-latest",
            )

    # Sampling parameters forwarded to respond_stream.
    with gr.Row(variant='panel'):
        temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, step=0.1, value=0.1)
        top_p = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, step=0.01, value=0.95)
        max_tokens = gr.Slider(label="Max Tokens", minimum=1000, maximum=32000, step=1000, value=8000)

    with gr.Row(variant='panel'):
        system = gr.Textbox(lines=2, label="System Message", value="You are a helpful AI assistant")

    # additional_inputs are passed to respond_stream positionally after
    # (message, chat_history), so this order must match its signature.
    gr.ChatInterface(
        respond_stream,
        chatbot=gr.Chatbot(render=False, height=500, layout="panel"),
        additional_inputs=[
            api_key,
            model,
            temperature,
            top_p,
            max_tokens,
            system,
        ],
    )

    with gr.Row():
        # Use target='_blank' (not the non-existent `_target` attribute) so the
        # links open in a new tab. rel='noopener noreferrer' keeps the opened
        # page from reaching back through window.opener.
        gr.HTML(
            value=(
                "<p style='margin-top: 1rem; margin-bottom: 1rem; text-align: center;'>"
                "Developed by Marco Frodl, Principal Consultant for Generative AI @ "
                "<a href='https://go.mfr.one/tt-en' target='_blank' rel='noopener noreferrer'>Thinktecture AG</a>"
                " -- Released 06/09/2024 -- More about me on my "
                "<a href='https://go.mfr.one/marcofrodl-en' target='_blank' rel='noopener noreferrer'>profile page</a>"
                "</p>"
            )
        )

mistral_playground.launch()