arad1367 committed
Commit 13a0ef1
Parent: c9f2880

Update app.py

Files changed (1): app.py (+27, -39)
app.py CHANGED
@@ -2,7 +2,7 @@ import os
 import time
 import spaces
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 import gradio as gr
 from threading import Thread
 
@@ -10,12 +10,15 @@ MODEL_LIST = ["mistralai/mathstral-7B-v0.1"]
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 MODEL = os.environ.get("MODEL_ID")
 
+TITLE = "<h1><center>MathΣtral</center></h1>"
+
 PLACEHOLDER = """
 <center>
 <p>MathΣtral - I'm MisMath,Your Math advisor</p>
 </center>
 """
 
+
 CSS = """
 .duplicate-button {
     margin: auto !important;
@@ -23,36 +26,25 @@ CSS = """
     background: black !important;
     border-radius: 100vh !important;
 }
-h1 {
+h3 {
     text-align: center;
-    font-size: 2em;
-    color: #333;
 }
 """
 
-TITLE = "<h1><center>MathΣtral - Your Math advisor</center></h1>"
-
 device = "cuda" # for GPU usage or "cpu" for CPU usage
 
-quantization_config = BitsAndBytesConfig(
-    load_in_4bit=True,
-    bnb_4bit_compute_dtype=torch.bfloat16,
-    bnb_4bit_use_double_quant=True,
-    bnb_4bit_quant_type="nf4")
-
 tokenizer = AutoTokenizer.from_pretrained(MODEL)
 model = AutoModelForCausalLM.from_pretrained(
     MODEL,
     torch_dtype=torch.bfloat16,
     device_map="auto",
-    quantization_config=quantization_config)
+    ignore_mismatched_sizes=True)
 
 @spaces.GPU()
 def stream_chat(
     message: str,
-    history: list,
-    system_prompt: str,
-    temperature: float = 0.8,
+    history: list,
+    temperature: float = 0.3,
     max_new_tokens: int = 1024,
     top_p: float = 1.0,
     top_k: int = 20,
@@ -61,27 +53,28 @@ def stream_chat(
     print(f'message: {message}')
     print(f'history: {history}')
 
-    # Prepare the conversation as plain text
-    conversation_text = system_prompt + "\n"
+    conversation = []
     for prompt, answer in history:
-        conversation_text += f"User: {prompt}\nAssistant: {answer}\n"
-
-    conversation_text += f"User: {message}\n"
+        conversation.extend([
+            {"role": "user", "content": prompt},
+            {"role": "assistant", "content": answer},
+        ])
 
-    # Tokenize the conversation text
-    input_ids = tokenizer(conversation_text, return_tensors="pt").input_ids.to(model.device)
+    conversation.append({"role": "user", "content": message})
 
+    input_text=tokenizer.apply_chat_template(conversation, tokenize=False)
+    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
     streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
-
+
     generate_kwargs = dict(
-        input_ids=input_ids,
-        max_new_tokens=max_new_tokens,
-        do_sample=False if temperature == 0 else True,
-        top_p=top_p,
-        top_k=top_k,
-        temperature=temperature,
-        eos_token_id=[128001, 128008, 128009],
+        input_ids=inputs,
+        max_new_tokens = max_new_tokens,
+        do_sample = False if temperature == 0 else True,
+        top_p = top_p,
+        top_k = top_k,
+        temperature = temperature,
         streamer=streamer,
+        pad_token_id = 10,
     )
 
     with torch.no_grad():
@@ -92,8 +85,9 @@ def stream_chat(
     for new_text in streamer:
         buffer += new_text
         yield buffer
+
 
-chatbot = gr.Chatbot(height=500, placeholder=PLACEHOLDER)
+chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
 
 footer = """
 <div style="text-align: center; margin-top: 20px;">
@@ -114,16 +108,11 @@ with gr.Blocks(css=CSS, theme="Ajaxon6255/Emerald_Isle") as demo:
         fill_height=True,
         additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
         additional_inputs=[
-            gr.Textbox(
-                value="You are a helpful assistant for Math questions and complex calculations and programming and your name is MisMath",
-                label="System Prompt",
-                render=False,
-            ),
             gr.Slider(
                 minimum=0,
                 maximum=1,
                 step=0.1,
-                value=0.8,
+                value=0.3,
                 label="Temperature",
                 render=False,
             ),
@@ -168,7 +157,6 @@ with gr.Blocks(css=CSS, theme="Ajaxon6255/Emerald_Isle") as demo:
         ],
         cache_examples=False,
     )
-    gr.HTML(footer)
 
 
 if __name__ == "__main__":
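
The core change in this commit is prompt construction: the old hand-rolled "User:"/"Assistant:" string (seeded with a system prompt) is replaced by role-tagged message dicts run through the tokenizer's chat template. A minimal sketch of that templating path, using the model named in MODEL_LIST; the conversation turns are illustrative, and it assumes the tokenizer ships a chat template:

    # Sketch of the chat-templating path the commit switches to;
    # the turns below are made up for illustration.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("mistralai/mathstral-7B-v0.1")

    conversation = [
        {"role": "user", "content": "What is the derivative of x**2?"},
        {"role": "assistant", "content": "2*x"},
        {"role": "user", "content": "And its integral?"},
    ]

    # tokenize=False returns the formatted prompt as a string (for
    # Mistral-style templates, with [INST] ... [/INST] markers), which
    # the app then encodes separately via tokenizer.encode(...).
    prompt = tokenizer.apply_chat_template(conversation, tokenize=False)
    print(prompt)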
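
The streaming consumer loop kept at the end of stream_chat (for new_text in streamer: ... yield buffer) only works if generation runs concurrently; the thread launch itself falls in the gap between the hunks shown above. A typical shape for that piece, consistent with the app's "from threading import Thread" import and its generate_kwargs dict (a sketch under that assumption, not the commit's exact code):

    from threading import Thread

    def run_generation(model, generate_kwargs, streamer):
        # model.generate blocks until completion, so it runs on a worker
        # thread while this generator yields partial text to Gradio.
        thread = Thread(target=model.generate, kwargs=generate_kwargs)
        thread.start()
        buffer = ""
        for new_text in streamer:  # iteration ends when generation finishes
            buffer += new_text
            yield buffer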