File size: 3,940 Bytes
798b452
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import os

import dotenv
import gradio as gr  # type: ignore
from mistralai.client import MistralClient  # type: ignore
from mistralai.models.chat_completion import ChatMessage  # type: ignore

# Populate os.environ from a local .env file (if present) before reading config.
dotenv.load_dotenv()


# Server-side Mistral API key. An empty or whitespace-only value is normalised to
# None so it counts as "not configured": the key textbox is only shown when this
# is None, and an empty-but-set variable would otherwise hide the textbox while
# every request still fails for lack of a usable key.
MISTRAL_API_KEY = os.environ.get("MISTRAL_API_KEY", "").strip() or None


TITLE = """<h1 align="center">MistralAI Playground 💬</h1>"""
# (user, assistant) avatar images passed to the Chatbot; no image for the user.
# NOTE(review): the assistant avatar is a Gemini icon — consider a Mistral one.
AVATAR_IMAGES = (None, "https://media.roboflow.com/spaces/gemini-icon.png")


# UI components are created here and placed into the layout later via
# ``.render()`` inside the ``gr.Blocks`` context.
chatbot_component = gr.Chatbot(
    label="MistralAI", bubble_full_width=False, avatar_images=AVATAR_IMAGES, scale=2, height=400
)
text_prompt_component = gr.Textbox(placeholder="Hi there! [press Enter]", show_label=False, autofocus=True, scale=8)
run_button_component = gr.Button(value="Run", variant="primary", scale=1)
# Only visible when no server-side MISTRAL_API_KEY is configured.
mistral_key_component = gr.Textbox(
    label="MISTRAL API KEY",
    value="",
    type="password",
    placeholder="...",
    info="You have to provide your own MISTRAL_API_KEY for this app to function properly",
    visible=MISTRAL_API_KEY is None,
)
model_component = gr.Dropdown(
    choices=["mistral-tiny", "mistral-small", "mistral-medium"],
    label="Model",
    value="mistral-small",
    scale=1,
    type="value",
)
temperature_component = gr.Slider(
    minimum=0,
    maximum=1.0,
    value=0.7,
    step=0.05,
    label="Temperature",
    # Adjacent literals are concatenated, so each sentence must end with a space.
    info=(
        "What sampling temperature to use, between 0.0 and 1.0. "
        "Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. "
        "We generally recommend altering this or top_p but not both."
    ),
)

# Input order must match the positional parameters of user(text_prompt, history).
user_inputs = [
    text_prompt_component,
    chatbot_component,
]
# Input order must match bot(mistral_key, model, temperature, history).
bot_inputs = [
    mistral_key_component,
    model_component,
    temperature_component,
    chatbot_component,
]


# Module-level Mistral client, created lazily by ``bot``; None until then.
client: MistralClient | None = None


def preprocess_chat_history(history):
    """Turn Gradio chat pairs into a flat list of Mistral ``ChatMessage`` objects.

    Every ``(user_text, assistant_text)`` pair contributes a "user" message
    followed by an "assistant" message, in that order. A side that is empty or
    ``None`` is left out.
    """
    messages = []
    for user_text, assistant_text in history:
        for role, content in (("user", user_text), ("assistant", assistant_text)):
            if content:
                messages.append(ChatMessage(role=role, content=content))
    return messages


def bot(
    mistral_key: str | None,
    model: str,
    temperature: float,
    history,
):
    if not history:
        return history

    mistral_key = mistral_key or MISTRAL_API_KEY
    if not mistral_key:
        raise ValueError("MISTRAL_API_KEY is not set. Please follow the instructions in the README to set it up.")
    global client
    if client is None:
        client = MistralClient(api_key=mistral_key)  # TDOO: how to handle this if no GIL

    chat_history = preprocess_chat_history(history)
    history[-1][1] = ""
    for chunk in client.chat_stream(model=model, messages=chat_history, temperature=temperature):
        print("chunk", chunk)
        if chunk.choices and chunk.choices[0].delta.content:
            history[-1][1] += chunk.choices[0].delta.content
            yield history


def user(text_prompt: str, history):
    """Add the submitted prompt to the chat as a new turn and clear the textbox.

    Args:
        text_prompt: Text from the prompt box; an empty prompt adds no turn.
        history: Current Chatbot value (``[user, assistant]`` pairs), or None.

    Returns:
        ``("", new_history)``: the empty string resets the textbox. In
        ``new_history``, a non-empty prompt is appended as
        ``[text_prompt, None]``, with ``None`` standing for the reply that is
        not generated yet.
    """
    # Copy instead of mutating the caller's list; tolerate a missing history.
    new_history = list(history or [])
    if text_prompt:
        # A list rather than a tuple so the reply slot can be assigned in place.
        new_history.append([text_prompt, None])
    return "", new_history


with gr.Blocks() as demo:
    gr.HTML(TITLE)

    # Key box (when visible) above the conversation.
    with gr.Column():
        mistral_key_component.render()
        chatbot_component.render()

    # Prompt box with its Run button on a single row.
    with gr.Row():
        text_prompt_component.render()
        run_button_component.render()

    # Generation settings, collapsed by default.
    with gr.Accordion("Parameters", open=False):
        model_component.render()
        temperature_component.render()

    # Clicking Run and pressing Enter in the prompt box wire up the same
    # two-step pipeline: `user` (run with queue=False) records the prompt, then
    # `bot` streams the reply into the chatbot.
    for trigger in (run_button_component.click, text_prompt_component.submit):
        trigger(
            fn=user,
            inputs=user_inputs,
            outputs=[text_prompt_component, chatbot_component],
            queue=False,
        ).then(
            fn=bot,
            inputs=bot_inputs,
            outputs=[chatbot_component],
        )


demo.queue(max_size=99)
demo.launch(debug=False, show_error=True)