"""Gradio chat playground that streams completions from the Mistral AI API.

The API key is read from the ``MISTRAL_API_KEY`` environment variable (a
``.env`` file is loaded first). When the variable is missing, a password
field is shown so the user can provide a key in the UI instead.
"""

import functools
import os

import dotenv
import gradio as gr  # type: ignore
from mistralai.client import MistralClient  # type: ignore
from mistralai.models.chat_completion import ChatMessage  # type: ignore

dotenv.load_dotenv()

MISTRAL_API_KEY = os.environ.get("MISTRAL_API_KEY")

# NOTE(review): this is passed to gr.HTML but contains no markup; heading tags
# may have been lost at some point -- confirm and restore them if so.
TITLE = """

MistralAI Playground 💬

"""

# (user avatar, bot avatar); None keeps the default user avatar.
AVATAR_IMAGES = (None, "https://media.roboflow.com/spaces/gemini-icon.png")

chatbot_component = gr.Chatbot(
    label="MistralAI",
    bubble_full_width=False,
    avatar_images=AVATAR_IMAGES,
    scale=2,
    height=400,
)
text_prompt_component = gr.Textbox(
    placeholder="Hi there! [press Enter]",
    show_label=False,
    autofocus=True,
    scale=8,
)
run_button_component = gr.Button(value="Run", variant="primary", scale=1)
mistral_key_component = gr.Textbox(
    label="MISTRAL API KEY",
    value="",
    type="password",
    placeholder="...",
    info="You have to provide your own MISTRAL_API_KEY for this app to function properly",
    # Only ask for a key in the UI when the environment does not supply one.
    visible=MISTRAL_API_KEY is None,
)
model_component = gr.Dropdown(
    choices=["mistral-tiny", "mistral-small", "mistral-medium"],
    label="Model",
    value="mistral-small",
    scale=1,
    type="value",
)
temperature_component = gr.Slider(
    minimum=0,
    maximum=1.0,
    value=0.7,
    step=0.05,
    label="Temperature",
    info=(
        "What sampling temperature to use, between 0.0 and 1.0. "
        "Higher values like 0.8 will make the output more random, while lower "
        "values like 0.2 will make it more focused and deterministic. "
        "We generally recommend altering this or top_p but not both."
    ),
)

user_inputs = [
    text_prompt_component,
    chatbot_component,
]
bot_inputs = [
    mistral_key_component,
    model_component,
    temperature_component,
    chatbot_component,
]


@functools.lru_cache(maxsize=16)
def _get_client(api_key: str) -> MistralClient:
    """Return a client for *api_key*, reusing one built earlier for the same key.

    Keying the cache on the API key ensures a key typed into the UI is never
    silently replaced by a client that was created for a different key.
    """
    return MistralClient(api_key=api_key)


def preprocess_chat_history(history) -> list:
    """Convert Gradio chat history into a flat list of ``ChatMessage`` objects.

    Args:
        history: Iterable of ``(user_text, assistant_text)`` pairs. Either
            element may be empty or ``None``, in which case it is skipped.

    Returns:
        Messages in conversation order, alternating ``user`` / ``assistant``
        roles for the entries that are present.
    """
    chat_history = []
    for human, assistant in history:
        if human:
            chat_history.append(ChatMessage(role="user", content=human))
        if assistant:
            chat_history.append(ChatMessage(role="assistant", content=assistant))
    return chat_history


def bot(
    mistral_key: str | None,
    model: str,
    temperature: float,
    history,
):
    """Stream the assistant reply for the last user turn in *history*.

    Args:
        mistral_key: API key entered in the UI; falls back to the
            ``MISTRAL_API_KEY`` environment value when empty.
        model: Name of the model to query.
        temperature: Sampling temperature forwarded to the API.
        history: Chat history as ``[user_text, assistant_text]`` pairs; the
            last pair is expected to have no assistant reply yet.

    Yields:
        The history with the last assistant reply growing as chunks arrive.

    Raises:
        ValueError: If no API key is available from the UI or the environment.
    """
    # Nothing to answer: either there is no conversation yet, or the last turn
    # already has a reply (e.g. Run was pressed with an empty prompt).
    # Re-querying here would erase the previous answer.
    if not history or history[-1][1] is not None:
        yield history
        return

    api_key = mistral_key or MISTRAL_API_KEY
    if not api_key:
        raise ValueError(
            "MISTRAL_API_KEY is not set. "
            "Please follow the instructions in the README to set it up."
        )

    mistral_client = _get_client(api_key)
    chat_history = preprocess_chat_history(history)

    # Rebuild the last pair as a list so it can be updated in place even if it
    # arrived as a tuple.
    history[-1] = [history[-1][0], ""]
    stream = mistral_client.chat_stream(
        model=model,
        messages=chat_history,
        temperature=temperature,
    )
    for chunk in stream:
        if chunk.choices and chunk.choices[0].delta.content:
            history[-1][1] += chunk.choices[0].delta.content
            yield history


def user(text_prompt: str, history):
    """Append the submitted prompt to *history* and clear the textbox.

    Args:
        text_prompt: Text currently in the prompt textbox.
        history: Chat history to extend; left untouched for an empty prompt.

    Returns:
        A pair of the new textbox value (always empty) and the history.
    """
    if text_prompt:
        # A list (not a tuple) so the reply slot can be filled in later.
        history.append([text_prompt, None])
    return "", history


with gr.Blocks() as demo:
    gr.HTML(TITLE)
    with gr.Column():
        mistral_key_component.render()
        chatbot_component.render()
        with gr.Row():
            text_prompt_component.render()
            run_button_component.render()
        with gr.Accordion("Parameters", open=False):
            model_component.render()
            temperature_component.render()

    # The Run button and pressing Enter share the same two-step pipeline:
    # record the user turn immediately (unqueued), then stream the reply.
    for trigger in (run_button_component.click, text_prompt_component.submit):
        trigger(
            fn=user,
            inputs=user_inputs,
            outputs=[text_prompt_component, chatbot_component],
            queue=False,
        ).then(
            fn=bot,
            inputs=bot_inputs,
            outputs=[chatbot_component],
        )

demo.queue(max_size=99).launch(debug=False, show_error=True)