MarcdeFalco committed (verified)
Commit ce78f1b · 1 Parent(s): 6f755ec

Migrate to transformers

Files changed (1)
  app.py  +67  -81
app.py CHANGED
@@ -1,83 +1,55 @@
-from huggingface_hub import InferenceClient
+import spaces
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
+from threading import Thread
 import gradio as gr
+import torch
 import os
 
-API_URL = {
-    "Mistral" : "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.3",
-    "Mixtral" : "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1",
-    "Mathstral" : "https://api-inference.huggingface.co/models/mistralai/mathstral-7B-v0.1",
-}
+device = "cuda"
 
-HF_TOKEN = os.environ['HF_TOKEN']
-
-mistralClient = InferenceClient(
-    API_URL["Mistral"],
-    headers = {"Authorization" : f"Bearer {HF_TOKEN}"},
-)
+model_name = "mistralai/mathstral-7B-v0.1"
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForCausalLM.from_pretrained(model_name,
+    torch_dtype=torch.float16).to(device)
 
-mixtralClient = InferenceClient(
-    model = API_URL["Mixtral"],
-    headers = {"Authorization" : f"Bearer {HF_TOKEN}"},
-)
-
-mathstralClient = InferenceClient(
-    model = API_URL["Mathstral"],
-    headers = {"Authorization" : f"Bearer {HF_TOKEN}"},
-)
+HF_TOKEN = os.environ['HF_TOKEN']
 
 def format_prompt(message, history):
-    prompt = "<s>"
-
+    prompt = ""
     for user_prompt, bot_response in history:
         prompt += f"[INST] {user_prompt} [/INST]"
-        prompt += f" {bot_response}</s> "
+        prompt += f" {bot_response} "
     prompt += f"[INST] {message} [/INST]"
     return prompt
 
-def generate(prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95,
-             repetition_penalty=1.0, model = "Mathstral"):
-    # Selecting model to be used
-    if(model == "Mistral"):
-        client = mistralClient
-    elif(model == "Mixstral"):
-        client = mixtralClient
-    elif(model == "Mathstral"):
-        client = mathstralClient
-
-
-    temperature = float(temperature) # Generation arguments
-    if temperature < 1e-2:
-        temperature = 1e-2
-
-    top_p = float(top_p)
+@spaces.GPU
+def generate(prompt, history,
+             max_new_tokens=1024,
+             repetition_penalty=1.2):
 
+    formatted_prompt = format_prompt(prompt, history)
+    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
+
+    streamer = TextIteratorStreamer(tokenizer)
     generate_kwargs = dict(
-        temperature=temperature,
-        max_new_tokens=max_new_tokens,
-        top_p=top_p,
-        repetition_penalty=repetition_penalty,
-        do_sample=True,
-        seed=42,
+        inputs,
+        streamer=streamer,
+        max_new_tokens=max_new_tokens,
+        repetition_penalty=repetition_penalty,
    )
-
-    formatted_prompt = format_prompt(prompt, history)
-    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
-    output = ""
-    for response in stream:
-        output += response.token.text
-        yield output
-    return output
+
+    thread = Thread(target=model.generate, kwargs=generate_kwargs)
+    thread.start()
+
+    text = ''
+    n = len('<s>') + len(formatted_prompt)
+    for word in streamer:
+        text += word
+        yield text[n:]
+    return text[n:]
+
 
 additional_inputs=[
-    gr.Slider(
-        label="Temperature",
-        value=0.3,
-        minimum=0.0,
-        maximum=1.0,
-        step=0.1,
-        interactive=True,
-        info="Higher values produce more diverse outputs",
-    ),
     gr.Slider(
         label="Max new tokens",
         value=1024,
@@ -87,15 +59,6 @@ additional_inputs=[
         interactive=True,
         info="The maximum numbers of new tokens",
     ),
-    gr.Slider(
-        label="Top-p (nucleus sampling)",
-        value=0.90,
-        minimum=0.0,
-        maximum=1,
-        step=0.05,
-        interactive=True,
-        info="Higher values sample more low-probability tokens",
-    ),
     gr.Slider(
         label="Repetition penalty",
         value=1.2,
@@ -105,15 +68,6 @@ additional_inputs=[
         interactive=True,
         info="Penalize repeated tokens",
     ),
-    gr.Dropdown(
-        choices = ["Mistral","Mixtral", "Mathstral"],
-        value = "Mathstral",
-        label = "Le modèle à utiliser",
-        interactive=True,
-        info = "Mistral : pour des conversations génériques, "+
-               "Mixtral : conversations plus rapides et plus performantes, "+
-               "Mathstral : raisonnement mathématiques et scientifique"
-    ),
 ]
 
 css = """
@@ -144,3 +98,35 @@ with gr.Blocks(css=css) as demo:
     )
 
 demo.queue(max_size=100).launch(debug=True)
+        : raisonnement mathématiques et scientifique"
+    ),
+]
+
+css = """
+#mkd {
+    height: 500px;
+    overflow: auto;
+    border: 1px solid #ccc;
+}
+"""
+
+with gr.Blocks(css=css) as demo:
+    gr.HTML("<h1><center>Mathstral Test</center><h1>")
+    gr.HTML("<h3><center>Dans cette démo, vous pouvez poser des questions mathématiques et scientifiques à Mathstral. 🧮</center><h3>")
+    gr.ChatInterface(
+        generate,
+        additional_inputs=additional_inputs,
+        theme = gr.themes.Soft(),
+        cache_examples=False,
+        examples=[ [l.strip()] for l in open("exercices.md").readlines()],
+        chatbot = gr.Chatbot(
+            latex_delimiters=[
+                {"left" : "$$", "right": "$$", "display": True },
+                {"left" : "\\[", "right": "\\]", "display": True },
+                {"left" : "\\(", "right": "\\)", "display": False },
+                {"left": "$", "right": "$", "display": False }
+            ]
+        )
+    )
+
+demo.queue(max_size=100).launch(debug=True)
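For reference, a minimal sketch of the threaded streaming pattern the new app.py relies on (local transformers generation with TextIteratorStreamer instead of the Inference API). The helper name stream_reply, the skip_prompt/skip_special_tokens options, and the example prompt are illustrative and not part of the commit; the model name and device follow the diff above.

    from threading import Thread

    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

    model_name = "mistralai/mathstral-7B-v0.1"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to("cuda")

    def stream_reply(prompt: str, max_new_tokens: int = 256):
        inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
        # skip_prompt=True makes the streamer yield only newly generated text,
        # so no manual slicing of the prompt is needed.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generate_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
        # model.generate blocks until generation finishes, so it runs in a
        # background thread while this generator consumes tokens as they arrive.
        Thread(target=model.generate, kwargs=generate_kwargs).start()
        text = ""
        for chunk in streamer:
            text += chunk
            yield text

    # Example usage (hypothetical prompt):
    # for partial in stream_reply("[INST] Prove that sqrt(2) is irrational. [/INST]"):
    #     print(partial)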