File size: 4,662 Bytes
798b452 a4b430d 798b452 d3f72ee 798b452 a4b430d 798b452 d3f72ee 798b452 d3f72ee 798b452 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
import os
import dotenv
import gradio as gr # type: ignore
from mistralai.client import MistralClient # type: ignore
from mistralai.models.chat_completion import ChatMessage # type: ignore
# Pull any variables defined in a local .env file into the process environment.
dotenv.load_dotenv()
# Server-side API key, or None when the variable is absent.
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
# Page heading, rendered with gr.HTML at the top of the Blocks layout.
TITLE = """<h1 align="center">MistralAI Playground 💬</h1>"""
# Banner shown under the title: a "Duplicate Space" badge linking to the
# Hugging Face duplicate flow, plus a link to the Mistral API-key console page.
DUPLICATE = """
<div style="text-align: center; display: flex; justify-content: center; align-items: center;">
<a href="https://huggingface.co/spaces/douglarek/MistralAI?duplicate=true">
<img src="https://bit.ly/3gLdBN6" alt="Duplicate Space" style="margin-right: 10px;">
</a>
<span>Duplicate the Space and run securely with your
<a href="https://console.mistral.ai/user/api-keys"> Mistral API KEY</a>.
</span>
</div>
"""
# Avatar images handed to gr.Chatbot(avatar_images=...); the first entry is None.
# NOTE(review): the second image is a Gemini icon, which looks like a leftover
# from another demo; confirm whether a Mistral icon was intended.
AVATAR_IMAGES = (None, "https://media.roboflow.com/spaces/gemini-icon.png")
# Conversation view; the avatar pair comes from AVATAR_IMAGES.
chatbot_component = gr.Chatbot(
    label="MistralAI",
    avatar_images=AVATAR_IMAGES,
    bubble_full_width=False,
    height=400,
    scale=2,
)
# Prompt row widgets: the text box (Enter submits) and the Run / Clear buttons.
text_prompt_component = gr.Textbox(
    placeholder="Hi there! [press Enter]",
    autofocus=True,
    show_label=False,
    scale=8,
)
run_button_component = gr.Button(value="Run", scale=1, variant="primary")
clear_button_component = gr.ClearButton(value="Clear", scale=1, variant="secondary")
# Password box for users who bring their own Mistral API key.
# It is shown whenever the environment did not supply a usable key. An empty
# MISTRAL_API_KEY (e.g. a blank `MISTRAL_API_KEY=` line in .env) counts as
# missing, matching bot()'s `not mistral_key` check. Otherwise the box would be
# hidden while every request fails for lack of a key.
mistral_key_component = gr.Textbox(
    label="MISTRAL API KEY",
    value="",
    type="password",
    placeholder="...",
    info="You have to provide your own MISTRAL_API_KEY for this app to function properly",
    visible=not MISTRAL_API_KEY,
)
# Model picker; the selected string is handed to bot() as its `model` argument.
model_component = gr.Dropdown(
    label="Model",
    choices=["mistral-tiny", "mistral-small", "mistral-medium"],
    value="mistral-small",
    type="value",
    scale=1,
)
# Sampling-temperature slider (0.0-1.0, default 0.7), passed to bot() as `temperature`.
# The info sentences are adjacent string literals that Python joins, so each one
# must end with its own trailing space. Without it, the UI rendered
# "deterministic.We generally ...".
# NOTE(review): this app exposes no top_p control; the last sentence is presumably
# carried over from the API parameter docs.
temperature_component = gr.Slider(
    minimum=0,
    maximum=1.0,
    value=0.7,
    step=0.05,
    label="Temperature",
    info=(
        "What sampling temperature to use, between 0.0 and 1.0. "
        "Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. "
        "We generally recommend altering this or top_p but not both."
    ),
)
# Components whose values are passed to `user`: the prompt text and the chat history.
user_inputs = [
    text_prompt_component,
    chatbot_component,
]
# Components whose values are passed to `bot`, in the order of its parameters
# (mistral_key, model, temperature, history).
bot_inputs = [
    mistral_key_component,
    model_component,
    temperature_component,
    chatbot_component,
]
# Module-level MistralClient slot. It starts as None, so the annotation must be
# Optional rather than a bare MistralClient.
client: MistralClient | None = None
def preprocess_chat_history(history):
    """Translate Gradio chatbot pairs into a flat list of Mistral ``ChatMessage`` objects.

    Each ``(human, assistant)`` pair contributes, in order, a ``user`` message
    and then an ``assistant`` message. A side whose text is falsy (``None`` or
    ``""``) is left out, so an unanswered final turn yields only its user message.
    """
    messages = []
    for user_text, bot_text in history:
        turn = (("user", user_text), ("assistant", bot_text))
        messages.extend(ChatMessage(role=role, content=text) for role, text in turn if text)
    return messages
def bot(
mistral_key: str | None,
model: str,
temperature: float,
history,
):
if not history:
return history
mistral_key = mistral_key or MISTRAL_API_KEY
if not mistral_key:
raise ValueError("MISTRAL_API_KEY is not set. Please follow the instructions in the README to set it up.")
global client
if client is None:
client = MistralClient(api_key=mistral_key) # TDOO: how to handle this if no GIL
chat_history = preprocess_chat_history(history)
history[-1][1] = ""
for chunk in client.chat_stream(model=model, messages=chat_history, temperature=temperature):
print("chunk", chunk)
if chunk.choices and chunk.choices[0].delta.content:
history[-1][1] += chunk.choices[0].delta.content
yield history
def user(text_prompt: str, history):
    """Append the submitted prompt to the chat history as a new, unanswered turn.

    Args:
        text_prompt: Text from the prompt box. An empty prompt adds no turn.
        history: Current chatbot value (a list of pairs), or ``None``. The Clear
            button's handler sets the chatbot to ``None``.

    Returns:
        ``("", new_history)``. The empty string clears the prompt box, and
        ``new_history`` is a copy of ``history`` with ``(text_prompt, None)``
        appended when a prompt was given.
    """
    # Work on a copy so the caller's list is untouched, and treat a missing
    # value as an empty conversation instead of crashing on `.append`.
    new_history = list(history or [])
    if text_prompt:
        new_history.append((text_prompt, None))
    return "", new_history
with gr.Blocks() as demo:
    # Header: the page title followed by the "duplicate this Space" banner.
    for header_html in (TITLE, DUPLICATE):
        gr.HTML(header_html)

    with gr.Column():
        mistral_key_component.render()
        chatbot_component.render()
        # Prompt box with its Run / Clear buttons on a single row.
        with gr.Row():
            for row_widget in (text_prompt_component, run_button_component, clear_button_component):
                row_widget.render()
        # Generation settings, collapsed by default.
        with gr.Accordion("Parameters", open=False):
            for param_widget in (model_component, temperature_component):
                param_widget.render()

    def _chain_user_then_bot(trigger):
        """Register `user` (unqueued) on *trigger*, followed by the streaming `bot` step."""
        user_step = trigger(
            fn=user,
            inputs=user_inputs,
            outputs=[text_prompt_component, chatbot_component],
            queue=False,
        )
        user_step.then(fn=bot, inputs=bot_inputs, outputs=[chatbot_component])

    # The Run button and the Enter key share the same user -> bot chain; Clear
    # resets both the prompt box and the chat. Registration order: Run, Clear, Enter.
    _chain_user_then_bot(run_button_component.click)
    clear_button_component.click(
        lambda: (None, None),
        outputs=[text_prompt_component, chatbot_component],
        queue=False,
    )
    _chain_user_then_bot(text_prompt_component.submit)

demo.queue(max_size=99).launch(debug=False, show_error=True)
|