|
import os |
|
|
|
import dotenv |
|
import gradio as gr |
|
from mistralai.client import MistralClient |
|
from mistralai.models.chat_completion import ChatMessage |
|
|
|
# Load variables from a local .env file into the process environment before
# the API key is read below.
dotenv.load_dotenv()

# Mistral API key taken from the environment; None when the variable is unset.
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")

# HTML header rendered at the top of the page.
TITLE = """<h1 align="center">MistralAI Playground 💬</h1>"""
# Avatar images for the chatbot as (user, bot); the user side has no avatar.
AVATAR_IMAGES = (None, "https://media.roboflow.com/spaces/gemini-icon.png")
|
|
|
|
|
# Chat transcript shown to the user; also used as the history input/output of
# the event handlers.
chatbot_component = gr.Chatbot(
    label="MistralAI", bubble_full_width=False, avatar_images=AVATAR_IMAGES, scale=2, height=400
)

# Prompt box; pressing Enter submits it.
text_prompt_component = gr.Textbox(placeholder="Hi there! [press Enter]", show_label=False, autofocus=True, scale=8)

run_button_component = gr.Button(value="Run", variant="primary", scale=1)

# Per-session API key input. It is shown whenever no usable key comes from the
# environment. An empty MISTRAL_API_KEY counts as missing, matching how `bot`
# falls back with `mistral_key or MISTRAL_API_KEY`.
mistral_key_component = gr.Textbox(
    label="MISTRAL API KEY",
    value="",
    type="password",
    placeholder="...",
    info="You have to provide your own MISTRAL_API_KEY for this app to function properly",
    visible=not MISTRAL_API_KEY,
)

# Model name passed straight through to the chat API.
model_component = gr.Dropdown(
    choices=["mistral-tiny", "mistral-small", "mistral-medium"],
    label="Model",
    value="mistral-small",
    scale=1,
    type="value",
)

temperature_component = gr.Slider(
    minimum=0,
    maximum=1.0,
    value=0.7,
    step=0.05,
    label="Temperature",
    # Adjacent literals are concatenated, so each sentence must end with its
    # own trailing space.
    info=(
        "What sampling temperature to use, between 0.0 and 1.0. "
        "Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. "
        "We generally recommend altering this or top_p but not both."
    ),
)
|
|
|
# Inputs of the `user` handler, in parameter order: (text_prompt, history).
user_inputs = [text_prompt_component, chatbot_component]

# Inputs of the `bot` handler, in parameter order:
# (mistral_key, model, temperature, history).
bot_inputs = [
    mistral_key_component,
    model_component,
    temperature_component,
    chatbot_component,
]

# Mistral client created lazily by `bot`; None until the first request.
client: MistralClient | None = None
|
|
|
|
|
def preprocess_chat_history(history): |
|
chat_history = [] |
|
for human, assistant in history: |
|
if human: |
|
chat_history.append(ChatMessage(role="user", content=human)) |
|
if assistant: |
|
chat_history.append(ChatMessage(role="assistant", content=assistant)) |
|
return chat_history |
|
|
|
|
|
def bot( |
|
mistral_key: str | None, |
|
model: str, |
|
temperature: float, |
|
history, |
|
): |
|
if not history: |
|
return history |
|
|
|
mistral_key = mistral_key or MISTRAL_API_KEY |
|
if not mistral_key: |
|
raise ValueError("MISTRAL_API_KEY is not set. Please follow the instructions in the README to set it up.") |
|
global client |
|
if client is None: |
|
client = MistralClient(api_key=mistral_key) |
|
|
|
chat_history = preprocess_chat_history(history) |
|
history[-1][1] = "" |
|
for chunk in client.chat_stream(model=model, messages=chat_history, temperature=temperature): |
|
print("chunk", chunk) |
|
if chunk.choices and chunk.choices[0].delta.content: |
|
history[-1][1] += chunk.choices[0].delta.content |
|
yield history |
|
|
|
|
|
def user(text_prompt: str, history):
    """Record the submitted prompt as a new, unanswered chat turn.

    Returns:
        ``("", history)``. The empty string clears the prompt textbox, and
        ``history`` is the same list passed in. A non-empty prompt is appended
        to it in place as ``(text_prompt, None)``; an empty prompt leaves it
        unchanged.
    """
    if not text_prompt:
        return "", history
    history.append((text_prompt, None))
    return "", history
|
|
|
|
|
with gr.Blocks() as demo:
    gr.HTML(TITLE)
    with gr.Column():
        mistral_key_component.render()
        chatbot_component.render()
        with gr.Row():
            text_prompt_component.render()
            run_button_component.render()
        with gr.Accordion("Parameters", open=False):
            model_component.render()
            temperature_component.render()

    # Clicking "Run" and pressing Enter in the prompt box trigger the same
    # chain. First `user` records the prompt (with queue=False, bypassing the
    # queue), then `bot` streams the reply into the chatbot.
    for trigger in (run_button_component.click, text_prompt_component.submit):
        trigger(
            fn=user,
            inputs=user_inputs,
            outputs=[text_prompt_component, chatbot_component],
            queue=False,
        ).then(fn=bot, inputs=bot_inputs, outputs=[chatbot_component])
|
|
|
|
|
# Enable the request queue with at most 99 pending events, then start the
# server. show_error=True displays handler exceptions (e.g. a missing API key)
# in the UI.
demo.queue(max_size=99)
demo.launch(debug=False, show_error=True)
|
|