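# app.py: Gradio demo that chats with mistralai/mathstral-7B-v0.1, streaming
# the reply token by token and rendering LaTeX in the chat window.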
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from threading import Thread
import gradio as gr
import torch
import os

device = "cuda"

# Hub token from the Space secrets; passed to from_pretrained so gated or
# private checkpoints can be downloaded.
HF_TOKEN = os.environ["HF_TOKEN"]

model_name = "mistralai/mathstral-7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, token=HF_TOKEN
).to(device)

def format_prompt(message, history):
    """Flatten (user, assistant) history into Mistral's [INST] chat format."""
    prompt = ""
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s>"
    prompt += f"[INST] {message} [/INST]"
    return prompt
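# An equivalent, more robust route is the tokenizer's built-in chat template
# (a sketch; assumes the tokenizer ships a Mistral chat template):
#   messages = [{"role": "user", "content": message}]
#   prompt = tokenizer.apply_chat_template(messages, tokenize=False)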

@spaces.GPU  # Request a ZeroGPU device for the duration of each call.
def generate(prompt, history,
             max_new_tokens=1024,
             repetition_penalty=1.2):

    formatted_prompt = format_prompt(prompt, history)
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)

    # skip_prompt drops the echoed input so only newly generated text is
    # streamed; this replaces fragile character-offset slicing of the output.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True,
                                    skip_special_tokens=True)
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        repetition_penalty=repetition_penalty,
    )

    # Run generation on a background thread; the streamer yields decoded text
    # chunks as they are produced.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    text = ""
    for chunk in streamer:
        text += chunk
        yield text
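# Quick local smoke test (assumes a CUDA GPU; hypothetical example prompt):
#   reply = ""
#   for reply in generate("Compute the derivative of x^2.", history=[]):
#       pass
#   print(reply)  # full answer once the stream finishes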


# Extra controls shown under the chat box; ChatInterface passes their values,
# in order, as generate's max_new_tokens and repetition_penalty arguments.
additional_inputs = [
    gr.Slider(
        label="Max new tokens",
        value=1024,
        minimum=0,
        maximum=4096,
        step=256,
        interactive=True,
        info="The maximum numbers of new tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]

css = """
  #mkd {
    height: 500px;
    overflow: auto;
    border: 1px solid #ccc;
  }
"""

# Page layout: heading, description, and the streaming chat interface.
with gr.Blocks(css=css) as demo:
    gr.HTML("<h1><center>Mathstral Test</center></h1>")
    gr.HTML("<h3><center>In this demo, you can ask Mathstral mathematical and scientific questions. 🧮</center></h3>")
    gr.ChatInterface(
        generate,
        additional_inputs=additional_inputs,
        theme=gr.themes.Soft(),
        cache_examples=False,
        # One example question per non-empty line of exercices.md.
        examples=[[l.strip()] for l in open("exercices.md") if l.strip()],
        chatbot=gr.Chatbot(
            # Render display ($$...$$, \[...\]) and inline ($...$, \(...\)) LaTeX.
            latex_delimiters=[
                {"left": "$$", "right": "$$", "display": True},
                {"left": "\\[", "right": "\\]", "display": True},
                {"left": "\\(", "right": "\\)", "display": False},
                {"left": "$", "right": "$", "display": False},
            ]
        ),
    )

# Queue requests (at most 100 waiting) and launch the app with debug logging.
demo.queue(max_size=100).launch(debug=True)