arad1367 committed on
Commit
dcf5029
1 Parent(s): 13a0ef1

Update app.py

Files changed (1)
  1. app.py +64 -37
app.py CHANGED
@@ -2,23 +2,25 @@ import os
 import time
 import spaces
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
 import gradio as gr
 from threading import Thread
 
+# Define constants and configuration
 MODEL_LIST = ["mistralai/mathstral-7B-v0.1"]
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 MODEL = os.environ.get("MODEL_ID")
 
-TITLE = "<h1><center>MathΣtral</center></h1>"
-
 PLACEHOLDER = """
 <center>
-<p>MathΣtral - I'm MisMath,Your Math advisor</p>
+<p>MathΣtral - Your Math advisor</p>
+<p>Hi! I'm MisMath. A Math advisor. My model is based on mathstral-7B-v0.1. Feel free to ask your questions</p>
+<p>Mathstral 7B is a model specializing in mathematical and scientific tasks, based on Mistral 7B.</p>
+<p>mathstral-7B-v0.1 is the first Mathstral model</p>
+<img src="Mistral.png" alt="MathStral Model" style="width:300px;height:200px;">
 </center>
 """
 
-
 CSS = """
 .duplicate-button {
     margin: auto !important;
@@ -26,69 +28,87 @@ CSS = """
     background: black !important;
     border-radius: 100vh !important;
 }
-h3 {
+h1 {
     text-align: center;
+    font-size: 2em;
+    color: #333;
 }
 """
 
-device = "cuda" # for GPU usage or "cpu" for CPU usage
+TITLE = "<h1><center>MathΣtral - Your Math advisor</center></h1>"
+
+device = "cuda" # for GPU usage or "cpu" for CPU usage
 
+# Configuration for model quantization
+quantization_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_compute_dtype=torch.bfloat16,
+    bnb_4bit_use_double_quant=True,
+    bnb_4bit_quant_type="nf4"
+)
+
+# Initialize tokenizer and model
 tokenizer = AutoTokenizer.from_pretrained(MODEL)
 model = AutoModelForCausalLM.from_pretrained(
     MODEL,
     torch_dtype=torch.bfloat16,
     device_map="auto",
-    ignore_mismatched_sizes=True)
+    quantization_config=quantization_config
+)
 
+# Define the chat streaming function
 @spaces.GPU()
 def stream_chat(
-    message: str,
-    history: list,
-    temperature: float = 0.3,
-    max_new_tokens: int = 1024,
-    top_p: float = 1.0,
-    top_k: int = 20,
+    message: str,
+    history: list,
+    system_prompt: str,
+    temperature: float = 0.8,
+    max_new_tokens: int = 1024,
+    top_p: float = 1.0,
+    top_k: int = 20,
     penalty: float = 1.2,
 ):
     print(f'message: {message}')
     print(f'history: {history}')
 
-    conversation = []
+    # Prepare the conversation context
+    conversation_text = system_prompt + "\n"
     for prompt, answer in history:
-        conversation.extend([
-            {"role": "user", "content": prompt},
-            {"role": "assistant", "content": answer},
-        ])
+        conversation_text += f"User: {prompt}\nAssistant: {answer}\n"
+
+    conversation_text += f"User: {message}\nAssistant:"
 
-    conversation.append({"role": "user", "content": message})
+    # Tokenize the conversation text
+    input_ids = tokenizer(conversation_text, return_tensors="pt").input_ids.to(model.device)
 
-    input_text = tokenizer.apply_chat_template(conversation, tokenize=False)
-    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
     streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
-
+
     generate_kwargs = dict(
-        input_ids=inputs,
-        max_new_tokens=max_new_tokens,
-        do_sample=False if temperature == 0 else True,
-        top_p=top_p,
-        top_k=top_k,
-        temperature=temperature,
+        input_ids=input_ids,
+        max_new_tokens=max_new_tokens,
+        do_sample=False if temperature == 0 else True,
+        top_p=top_p,
+        top_k=top_k,
+        temperature=temperature,
+        eos_token_id=[128001, 128008, 128009],
         streamer=streamer,
-        pad_token_id=10,
     )
 
     with torch.no_grad():
         thread = Thread(target=model.generate, kwargs=generate_kwargs)
         thread.start()
-
+
     buffer = ""
     for new_text in streamer:
        buffer += new_text
-        yield buffer
+        # Clean the buffer to remove unwanted prefixes
+        cleaned_text = buffer.split("Assistant:")[-1].strip()
+        yield cleaned_text
 
-
-chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
+# Define the Gradio chatbot component
+chatbot = gr.Chatbot(height=500, placeholder=PLACEHOLDER)
 
+# Define the footer with links
 footer = """
 <div style="text-align: center; margin-top: 20px;">
 <a href="https://www.linkedin.com/in/pejman-ebrahimi-4a60151a7/" target="_blank">LinkedIn</a> |
@@ -99,6 +119,7 @@ footer = """
 </div>
 """
 
+# Create and launch the Gradio interface
 with gr.Blocks(css=CSS, theme="Ajaxon6255/Emerald_Isle") as demo:
     gr.HTML(TITLE)
     gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
@@ -108,11 +129,16 @@ with gr.Blocks(css=CSS, theme="Ajaxon6255/Emerald_Isle") as demo:
         fill_height=True,
         additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
         additional_inputs=[
+            gr.Textbox(
+                value="You are a helpful assistant for Math questions and complex calculations and programming and your name is MisMath",
+                label="System Prompt",
+                render=False,
+            ),
             gr.Slider(
                 minimum=0,
                 maximum=1,
                 step=0.1,
-                value=0.3,
+                value=0.8,
                 label="Temperature",
                 render=False,
             ),
@@ -157,7 +183,8 @@ with gr.Blocks(css=CSS, theme="Ajaxon6255/Emerald_Isle") as demo:
         ],
         cache_examples=False,
     )
+    gr.HTML(footer)
 
-
+# Launch the application
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
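
The central functional change in this commit is the switch from `ignore_mismatched_sizes=True` to bitsandbytes 4-bit quantization, which cuts the 7B model's weight memory roughly by a factor of four. Below is a minimal, self-contained sketch of the same loading pattern; it assumes a CUDA GPU with `bitsandbytes` and `accelerate` installed, and uses the model id from `MODEL_LIST` directly instead of the `MODEL_ID` environment variable.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "mistralai/mathstral-7B-v0.1"

# Same quantization settings as the commit: 4-bit NF4 storage,
# bfloat16 compute, and double quantization of the scaling constants.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # place layers on the available GPU(s)
    quantization_config=quantization_config,
)
```

NF4 stores weights in a 4-bit NormalFloat format while matrix multiplications still run in bfloat16, and double quantization compresses the per-block scaling constants as well, saving roughly another 0.4 bits per parameter.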
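The streaming setup is kept from the previous revision: `model.generate()` runs in a background thread and pushes decoded text into a `TextIteratorStreamer`, which the generator consumes as tokens arrive. A stripped-down sketch of just that pattern, without the Gradio wiring (`stream_reply` is a hypothetical helper name):

```python
from threading import Thread
from transformers import TextIteratorStreamer

def stream_reply(model, tokenizer, prompt: str, max_new_tokens: int = 256):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    # skip_prompt=True keeps the echoed input out of the stream;
    # timeout raises if generation stalls for 60 seconds.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
    )
    thread = Thread(
        target=model.generate,
        kwargs=dict(input_ids=input_ids, max_new_tokens=max_new_tokens, streamer=streamer),
    )
    thread.start()
    buffer = ""
    for new_text in streamer:  # blocks until the worker thread produces text
        buffer += new_text
        yield buffer           # the UI re-renders the growing reply on each yield
    thread.join()
```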
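Prompt construction changed more substantially: the old revision rendered a list of role dictionaries through `tokenizer.apply_chat_template`, while the new `stream_chat` concatenates a plain `User:`/`Assistant:` transcript seeded with the system prompt and then splits the streamed buffer on `"Assistant:"`. For a one-turn history the new format looks like this (values are illustrative):

```python
system_prompt = "You are a helpful assistant for Math questions and your name is MisMath"
history = [("What is 2 + 2?", "4")]
message = "And 3 * 3?"

# Same concatenation logic as stream_chat in this commit.
conversation_text = system_prompt + "\n"
for prompt, answer in history:
    conversation_text += f"User: {prompt}\nAssistant: {answer}\n"
conversation_text += f"User: {message}\nAssistant:"

print(conversation_text)
# You are a helpful assistant for Math questions and your name is MisMath
# User: What is 2 + 2?
# Assistant: 4
# User: And 3 * 3?
# Assistant:
```

Note that `buffer.split("Assistant:")[-1]` returns the buffer unchanged when the marker is absent, so with `skip_prompt=True` the cleanup only takes effect if the model itself echoes the transcript format.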