arad1367 committed on
Commit c210132
1 Parent(s): 45685cc

Update app.py

Files changed (1)
  1. app.py +42 -53
app.py CHANGED
@@ -6,17 +6,15 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 import gradio as gr
 from threading import Thread
 
-# Define constants and configuration
 MODEL_LIST = ["mistralai/mathstral-7B-v0.1"]
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 MODEL = os.environ.get("MODEL_ID")
 
+TITLE = "<h1><center>MathΣtral</center></h1>"
+
 PLACEHOLDER = """
 <center>
-<p>MathΣtral - Your Math advisor</p>
-<p>Hi! I'm MisMath. A Math advisor. My model is based on mathstral-7B-v0.1. Feel free to ask your questions</p>
-<p>Mathstral 7B is a model specializing in mathematical and scientific tasks, based on Mistral 7B.</p>
-<p>mathstral-7B-v0.1 is the first Mathstral model</p>
+<p>MathΣtral - I'm MisMath, Your Math advisor</p>
 </center>
 """
 
@@ -27,86 +25,78 @@ CSS = """
     background: black !important;
     border-radius: 100vh !important;
 }
-h1 {
+h3 {
     text-align: center;
-    font-size: 2em;
-    color: #333;
 }
 """
 
-TITLE = "<h1><center>MathΣtral - Your Math advisor</center></h1>"
-
-device = "cuda" # for GPU usage or "cpu" for CPU usage
+device = "cuda" # for GPU usage or "cpu" for CPU usage
 
-# Configuration for model quantization
 quantization_config = BitsAndBytesConfig(
     load_in_4bit=True,
     bnb_4bit_compute_dtype=torch.bfloat16,
     bnb_4bit_use_double_quant=True,
-    bnb_4bit_quant_type="nf4"
-)
+    bnb_4bit_quant_type="nf4")
 
-# Initialize tokenizer and model
 tokenizer = AutoTokenizer.from_pretrained(MODEL)
 model = AutoModelForCausalLM.from_pretrained(
     MODEL,
     torch_dtype=torch.bfloat16,
     device_map="auto",
-    quantization_config=quantization_config
-)
+    quantization_config=quantization_config)
 
-# Define the chat streaming function
 @spaces.GPU()
 def stream_chat(
-    message: str,
+    message: str,
     history: list,
     system_prompt: str,
-    temperature: float = 0.8,
-    max_new_tokens: int = 1024,
-    top_p: float = 1.0,
-    top_k: int = 20,
+    temperature: float = 0.8,
+    max_new_tokens: int = 1024,
+    top_p: float = 1.0,
+    top_k: int = 20,
    penalty: float = 1.2,
 ):
-    # Prepare the conversation context
-    conversation_text = system_prompt + "\n"
-    for _, answer in history:
-        conversation_text += f"MisMath: {answer}\n"
-
-    conversation_text += f"User: {message}\nMisMath:"
-
-    # Tokenize the conversation text
-    input_ids = tokenizer(conversation_text, return_tensors="pt").input_ids.to(model.device)
+    print(f'message: {message}')
+    print(f'history: {history}')
+
+    conversation = [
+        {"role": "system", "content": system_prompt}
+    ]
+    for prompt, answer in history:
+        conversation.extend([
+            {"role": "user", "content": prompt},
+            {"role": "assistant", "content": answer},
+        ])
+
+    conversation.append({"role": "user", "content": message})
+
+    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(model.device)
 
     streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
 
     generate_kwargs = dict(
-        input_ids=input_ids,
-        max_new_tokens=max_new_tokens,
-        do_sample=False if temperature == 0 else True,
-        top_p=top_p,
-        top_k=top_k,
-        temperature=temperature,
-        eos_token_id=[128001, 128008, 128009],
+        input_ids=input_ids,
+        max_new_tokens = max_new_tokens,
+        do_sample = False if temperature == 0 else True,
+        top_p = top_p,
+        top_k = top_k,
+        temperature = temperature,
+        eos_token_id=[128001,128008,128009],
         streamer=streamer,
     )
 
     with torch.no_grad():
         thread = Thread(target=model.generate, kwargs=generate_kwargs)
         thread.start()
 
     buffer = ""
-    final_output = ""
     for new_text in streamer:
         buffer += new_text
-        # Extract only the final response after the last "MisMath:"
-        if "MisMath:" in buffer:
-            final_output = buffer.split("MisMath:")[-1].strip()
-            yield final_output
+        yield buffer
 
-# Define the Gradio chatbot component
 chatbot = gr.Chatbot(height=500, placeholder=PLACEHOLDER)
 
-# Define the footer with links
 footer = """
 <div style="text-align: center; margin-top: 20px;">
     <a href="https://www.linkedin.com/in/pejman-ebrahimi-4a60151a7/" target="_blank">LinkedIn</a> |
@@ -117,7 +107,6 @@ footer = """
     </div>
 """
 
-# Create and launch the Gradio interface
 with gr.Blocks(css=CSS, theme="Ajaxon6255/Emerald_Isle") as demo:
     gr.HTML(TITLE)
     gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
@@ -128,7 +117,7 @@ with gr.Blocks(css=CSS, theme="Ajaxon6255/Emerald_Isle") as demo:
         additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
         additional_inputs=[
             gr.Textbox(
-                value="You are a helpful assistant for Math questions and complex calculations and programming and your name is MisMath",
+                value="You are a helpful assistant in mathematical and scientific tasks",
                 label="System Prompt",
                 render=False,
             ),
@@ -183,6 +172,6 @@ with gr.Blocks(css=CSS, theme="Ajaxon6255/Emerald_Isle") as demo:
     )
     gr.HTML(footer)
 
-# Launch the application
+
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
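The substantive change in this commit is prompt construction: the old manual "User:/MisMath:" string concatenation (which also dropped past user turns, since the loop iterated only over answers) is replaced by the tokenizer's built-in chat template. A minimal sketch of what that call does, assuming the mathstral-7B-v0.1 tokenizer ships a chat template as recent mistralai releases do; the question string is illustrative only:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("mistralai/mathstral-7B-v0.1")
conversation = [
    {"role": "system", "content": "You are a helpful assistant in mathematical and scientific tasks"},
    {"role": "user", "content": "Differentiate x**2."},
]
# tokenize=False returns the rendered prompt string, which makes the
# template's role markers and special tokens visible for inspection;
# the app instead passes return_tensors="pt" to get input IDs directly.
prompt = tok.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
print(prompt)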
 
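The generate-in-a-thread pattern kept by this diff is what makes token-by-token streaming work: model.generate blocks until generation finishes, so it runs on a worker thread while the main thread iterates the streamer and yields partial text to Gradio. A self-contained sketch of the same mechanics, using gpt2 purely so it runs on CPU without a token (the Space itself uses the quantized Mathstral model):

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tok = AutoTokenizer.from_pretrained("gpt2")
mdl = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("The quadratic formula is", return_tensors="pt")
streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)

# generate() blocks, so it runs in a worker thread; the streamer is an
# iterator that yields decoded text chunks as tokens are produced.
thread = Thread(target=mdl.generate, kwargs=dict(**inputs, max_new_tokens=30, streamer=streamer))
thread.start()
for chunk in streamer:
    print(chunk, end="", flush=True)
thread.join()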
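One caveat on the new generate_kwargs: the hard-coded eos_token_id=[128001,128008,128009] values appear to be Llama-3 special-token IDs, while mathstral-7B-v0.1 uses a much smaller Mistral vocabulary, so those IDs would fall outside it. A more portable pattern reads the ID from the tokenizer itself:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("mistralai/mathstral-7B-v0.1")
print(tok.eos_token, tok.eos_token_id)  # the model's actual end-of-sequence token

# generate_kwargs could then use:
#     eos_token_id=tok.eos_token_id,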