Spaces:
Sleeping
Sleeping
File size: 4,296 Bytes
9ab2b8f b920220 9ab2b8f 269a39e 9ab2b8f 269a39e 9ab2b8f 269a39e 9ab2b8f 269a39e 9ab2b8f 269a39e 9ab2b8f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
from huggingface_hub import InferenceClient
import gradio as gr
import os
# Inference endpoints for each selectable model, keyed by the names offered
# in the model dropdown.
API_URL = {
    "Mistral": "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.3",
    "Mixtral": "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1",
    "Mathstral": "https://api-inference.huggingface.co/models/mistralai/mathstral-7B-v0.1",
}

# Hugging Face access token; indexing os.environ raises KeyError when unset.
HF_TOKEN = os.environ["HF_TOKEN"]


def _client_for(name):
    """Return an InferenceClient for ``API_URL[name]`` authenticated with HF_TOKEN."""
    return InferenceClient(
        model=API_URL[name],
        headers={"Authorization": f"Bearer {HF_TOKEN}"},
    )


mistralClient = _client_for("Mistral")
mixtralClient = _client_for("Mixtral")
mathstralClient = _client_for("Mathstral")
def format_prompt(message, history):
    """Build a Mistral-instruct style prompt from prior turns and a new message.

    The prompt starts with ``<s>``. Each earlier (user, bot) pair in
    ``history`` is rendered as ``[INST] user [/INST] bot</s> ``. The new
    ``message`` is appended as a final open ``[INST] ... [/INST]`` turn for
    the model to answer.
    """
    past_turns = [
        f"[INST] {user_turn} [/INST] {bot_turn}</s> "
        for user_turn, bot_turn in history
    ]
    return "<s>" + "".join(past_turns) + f"[INST] {message} [/INST]"
def generate(prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95,
             repetition_penalty=1.0, model="Mathstral"):
    """Stream a reply to ``prompt`` from the selected Mistral-family model.

    Args:
        prompt: The latest user message.
        history: Earlier turns as (user, bot) pairs, rendered by ``format_prompt``.
        temperature: Sampling temperature; raised to at least 0.01.
        max_new_tokens: Generation length limit; coerced to an int >= 1.
        top_p: Nucleus-sampling threshold; clamped into [0.01, 0.99].
        repetition_penalty: Penalty for repeated tokens, passed through unchanged.
        model: "Mistral", "Mixtral" or "Mathstral" (the dropdown choices).

    Yields:
        The accumulated response text after each streamed token.

    Raises:
        ValueError: If ``model`` is not one of the known model names.
    """
    # Select the client. "Mixstral" (a former misspelling) is still accepted
    # as an alias so no caller relying on it breaks.
    if model == "Mistral":
        client = mistralClient
    elif model in ("Mixtral", "Mixstral"):
        client = mixtralClient
    elif model == "Mathstral":
        client = mathstralClient
    else:
        raise ValueError(f"Unknown model: {model!r}")

    # Keep temperature strictly positive: the UI slider allows 0.0 while
    # sampling is enabled below.
    temperature = max(float(temperature), 1e-2)
    # NOTE(review): text-generation-inference validates 0 < top_p < 1, and the
    # slider endpoints are 0.0 and 1, so clamp just inside that interval.
    top_p = min(max(float(top_p), 1e-2), 0.99)
    # Sliders may deliver floats (e.g. 1024.0) and the slider minimum is 0;
    # request at least one whole token.
    max_new_tokens = max(1, int(max_new_tokens))

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,  # fixed seed: same inputs give reproducible samples
    )

    formatted_prompt = format_prompt(prompt, history)
    stream = client.text_generation(
        formatted_prompt,
        **generate_kwargs,
        stream=True,
        details=True,
        return_full_text=False,
    )
    output = ""
    for response in stream:
        output += response.token.text
        # Yield the running text so the chat UI updates incrementally.
        yield output
    return output
# Extra controls shown below the chat box. Their values reach ``generate``
# after (message, history) in list order, so the order here must match its
# parameters: temperature, max_new_tokens, top_p, repetition_penalty, model.
additional_inputs = [
    gr.Slider(
        label=label,
        value=value,
        minimum=minimum,
        maximum=maximum,
        step=step,
        interactive=True,
        info=info,
    )
    # (label, default, minimum, maximum, step, help text)
    for label, value, minimum, maximum, step, info in (
        ("Temperature", 0.3, 0.0, 1.0, 0.1,
         "Higher values produce more diverse outputs"),
        ("Max new tokens", 1024, 0, 4096, 256,
         "The maximum numbers of new tokens"),
        ("Top-p (nucleus sampling)", 0.90, 0.0, 1, 0.05,
         "Higher values sample more low-probability tokens"),
        ("Repetition penalty", 1.2, 1.0, 2.0, 0.05,
         "Penalize repeated tokens"),
    )
]
# Model selector; the choice strings are the model names ``generate`` accepts.
additional_inputs.append(
    gr.Dropdown(
        choices=["Mistral", "Mixtral", "Mathstral"],
        value="Mathstral",
        label="Le modèle à utiliser",
        interactive=True,
        info="Mistral : pour des conversations génériques, "
        "Mixtral : conversations plus rapides et plus performantes, "
        "Mathstral : raisonnement mathématiques et scientifique",
    )
)
# Page styles for an element with id "mkd".
# NOTE(review): no component in this file sets elem_id="mkd", so this rule
# appears unused; confirm before removing.
css = """
#mkd {
height: 500px;
overflow: auto;
border: 1px solid #ccc;
}
"""


def _load_examples(path="exercices.md"):
    """Return one single-field chat example per non-blank line of ``path``.

    The file is read as UTF-8 (it holds French text) and closed after
    reading. Blank lines are skipped so they do not become empty examples.
    """
    with open(path, encoding="utf-8") as handle:
        return [[line.strip()] for line in handle if line.strip()]


# The theme is set on the outer Blocks: a theme passed to a ChatInterface
# nested inside another Blocks is not the one applied to the page.
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    gr.HTML("<h1><center>Mathstral Test</center></h1>")
    gr.HTML("<h3><center>Dans cette démo, vous pouvez poser des questions mathématiques et scientifiques à Mathstral. 🧮</center></h3>")
    gr.ChatInterface(
        generate,
        additional_inputs=additional_inputs,
        cache_examples=False,
        examples=_load_examples(),
        chatbot=gr.Chatbot(
            # Render both display ($$, \[ \]) and inline (\( \), $) LaTeX.
            latex_delimiters=[
                {"left": "$$", "right": "$$", "display": True},
                {"left": "\\[", "right": "\\]", "display": True},
                {"left": "\\(", "right": "\\)", "display": False},
                {"left": "$", "right": "$", "display": False},
            ]
        ),
    )

# Queue at most 100 pending requests, then start the server.
demo.queue(max_size=100).launch(debug=True)
|