AlekseyKorshuk commited on
Commit
fb9f0a9
1 Parent(s): f3d785b
Files changed (2) hide show
  1. app.py +118 -72
  2. models/chatml.py +17 -0
app.py CHANGED
@@ -1,15 +1,14 @@
1
  import gradio as gr
2
- import random
3
- import time
4
  import os
5
  import firebase_admin
6
  from firebase_admin import db
7
  from firebase_admin import firestore
8
  from conversation import Conversation
9
  from models.base import BaseModel
10
- import requests
11
  import json
12
 
 
 
13
  HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN")
14
  FIREBASE_URL = os.environ.get("FIREBASE_URL")
15
  CERTIFICATE = json.loads(os.environ.get("CERTIFICATE"))
@@ -21,12 +20,12 @@ models = [
21
  endpoint="mpt-7b",
22
  namespace="tenant-chairesearch-test",
23
  generation_params={
24
- 'temperature': 1.0,
25
  'repetition_penalty': 1.0,
26
  'max_new_tokens': 128,
27
- 'top_k': 1,
28
- 'top_p': 1.0,
29
- 'do_sample': False,
30
  'eos_token_id': 187,
31
  }
32
  ),
@@ -35,15 +34,57 @@ models = [
35
  endpoint="mpt-7b-storywriter",
36
  namespace="tenant-chairesearch-test",
37
  generation_params={
38
- 'temperature': 1.0,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  'repetition_penalty': 1.0,
40
  'max_new_tokens': 128,
41
- 'top_k': 1,
42
- 'top_p': 1.0,
43
- 'do_sample': False,
44
  'eos_token_id': 187,
45
  }
46
- )
47
  ]
48
  model_mapping = {model.name: model for model in models}
49
 
@@ -115,65 +156,70 @@ def get_bot_profile(bot_config):
115
 
116
 
117
  with gr.Blocks() as demo:
118
- default_bot_id = "_bot_1ec22e2e-3e07-42c7-8508-dfa0278c1b33"
119
- bot_config = download_bot_config(default_bot_id)
120
- user_state = gr.State(
121
- bot_config
122
- )
123
- with gr.Row():
124
- bot_id = gr.Textbox(label="Chai bot ID", value=default_bot_id, interactive=True)
125
- reload_bot_button = gr.Button("Reload bot")
126
-
127
- bot_profile = gr.HTML(get_bot_profile(bot_config))
128
-
129
- first_message = (None, bot_config["firstMessage"])
130
- chatbot = gr.Chatbot([first_message])
131
-
132
- msg = gr.Textbox(label="Message", value="Hi there!")
133
- with gr.Row():
134
- clear = gr.Button("Clear")
135
- regenerate = gr.Button("Regenerate")
136
- values = list(model_mapping.keys())
137
- model_tag = gr.Dropdown(values, value=values[0], label="Model version")
138
-
139
-
140
- def respond(message, chat_history, user_state, model_tag):
141
- conv = Conversation(user_state)
142
- conv.set_chat_history(chat_history)
143
- conv.add_user_message(message)
144
- model = model_mapping[model_tag]
145
- bot_message = model.generate_response(conv)
146
- chat_history.append(
147
- (message, bot_message)
148
- )
149
- return "", chat_history
150
-
151
-
152
- def clear_chat(chat_history, user_state):
153
- chat_history = [(None, user_state["firstMessage"])]
154
- return "", chat_history
155
-
156
-
157
- def regenerate_response(chat_history, user_state, model_tag):
158
- last_row = chat_history.pop(-1)
159
- chat_history.append((last_row[0], None))
160
- model = model_mapping[model_tag]
161
- conv = Conversation(user_state)
162
- conv.set_chat_history(chat_history)
163
- bot_message = model.generate_response(conv)
164
- chat_history[-1] = (last_row[0], bot_message)
165
- return "", chat_history
166
-
167
-
168
- def reload_bot(bot_id, bot_profile, chat_history):
169
- bot_config = download_bot_config(bot_id)
170
- bot_profile = get_bot_profile(bot_config)
171
- return bot_profile, [(None, bot_config["firstMessage"])], bot_config
172
-
173
-
174
- msg.submit(respond, [msg, chatbot, user_state, model_tag], [msg, chatbot], queue=False)
175
- clear.click(clear_chat, [chatbot, user_state], [msg, chatbot], queue=False)
176
- regenerate.click(regenerate_response, [chatbot, user_state, model_tag], [msg, chatbot], queue=False)
177
- reload_bot_button.click(reload_bot, [bot_id, bot_profile, chatbot], [bot_profile, chatbot, user_state], queue=False)
 
 
 
 
 
178
 
179
  demo.launch(enable_queue=False)
 
1
  import gradio as gr
 
 
2
  import os
3
  import firebase_admin
4
  from firebase_admin import db
5
  from firebase_admin import firestore
6
  from conversation import Conversation
7
  from models.base import BaseModel
 
8
  import json
9
 
10
+ from models.chatml import ChatML
11
+
12
  HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN")
13
  FIREBASE_URL = os.environ.get("FIREBASE_URL")
14
  CERTIFICATE = json.loads(os.environ.get("CERTIFICATE"))
 
20
  endpoint="mpt-7b",
21
  namespace="tenant-chairesearch-test",
22
  generation_params={
23
+ 'temperature': 0.72,
24
  'repetition_penalty': 1.0,
25
  'max_new_tokens': 128,
26
+ 'top_k': 10,
27
+ 'top_p': 0.9,
28
+ 'do_sample': True,
29
  'eos_token_id': 187,
30
  }
31
  ),
 
34
  endpoint="mpt-7b-storywriter",
35
  namespace="tenant-chairesearch-test",
36
  generation_params={
37
+ 'temperature': 0.72,
38
+ 'repetition_penalty': 1.0,
39
+ 'max_new_tokens': 128,
40
+ 'top_k': 10,
41
+ 'top_p': 0.9,
42
+ 'do_sample': True,
43
+ 'eos_token_id': 187,
44
+ }
45
+ ),
46
+ ChatML(
47
+ name="mosaicml/mpt-7b-chat",
48
+ endpoint="mpt-7b-chat",
49
+ namespace="tenant-chairesearch-test",
50
+ generation_params={
51
+ 'temperature': 0.72,
52
+ 'repetition_penalty': 1.0,
53
+ 'max_new_tokens': 128,
54
+ 'top_k': 10,
55
+ 'top_p': 0.9,
56
+ 'do_sample': True,
57
+ 'eos_token_id': 50278,
58
+ }
59
+ ),
60
+ BaseModel(
61
+ name="togethercomputer/RedPajama-INCITE-Base-7B-v0.1",
62
+ endpoint="redpajama-base-7b",
63
+ namespace="tenant-chairesearch-test",
64
+ generation_params={
65
+ 'temperature': 0.72,
66
+ 'repetition_penalty': 1.0,
67
+ 'max_new_tokens': 128,
68
+ 'top_k': 10,
69
+ 'top_p': 0.9,
70
+ 'do_sample': True,
71
+ 'eos_token_id': 187,
72
+ }
73
+ ),
74
+ BaseModel(
75
+ name="togethercomputer/RedPajama-INCITE-Chat-7B-v0.1",
76
+ endpoint="redpajama-chat-7b",
77
+ namespace="tenant-chairesearch-test",
78
+ generation_params={
79
+ 'temperature': 0.72,
80
  'repetition_penalty': 1.0,
81
  'max_new_tokens': 128,
82
+ 'top_k': 10,
83
+ 'top_p': 0.9,
84
+ 'do_sample': True,
85
  'eos_token_id': 187,
86
  }
87
+ ),
88
  ]
89
  model_mapping = {model.name: model for model in models}
90
 
 
156
 
157
 
158
  with gr.Blocks() as demo:
159
+ with gr.Tabs():
160
+ with gr.TabItem("Playground"):
161
+ default_bot_id = "_bot_e21de304-6151-4a04-b025-4c553ae8cbca"
162
+ bot_config = download_bot_config(default_bot_id)
163
+ user_state = gr.State(
164
+ bot_config
165
+ )
166
+ with gr.Row():
167
+ bot_id = gr.Textbox(label="Chai bot ID", value=default_bot_id, interactive=True)
168
+ reload_bot_button = gr.Button("Reload bot")
169
+
170
+ bot_profile = gr.HTML(get_bot_profile(bot_config))
171
+
172
+ first_message = (None, bot_config["firstMessage"])
173
+ chatbot = gr.Chatbot([first_message])
174
+
175
+ msg = gr.Textbox(label="Message", value="Hi there!")
176
+ with gr.Row():
177
+ send = gr.Button("Send")
178
+ regenerate = gr.Button("Regenerate")
179
+ clear = gr.Button("Clear")
180
+ values = list(model_mapping.keys())
181
+ model_tag = gr.Dropdown(values, value=values[0], label="Model version")
182
+
183
+
184
+ def respond(message, chat_history, user_state, model_tag):
185
+ conv = Conversation(user_state)
186
+ conv.set_chat_history(chat_history)
187
+ conv.add_user_message(message)
188
+ model = model_mapping[model_tag]
189
+ bot_message = model.generate_response(conv)
190
+ chat_history.append(
191
+ (message, bot_message)
192
+ )
193
+ return "", chat_history
194
+
195
+
196
+ def clear_chat(chat_history, user_state):
197
+ chat_history = [(None, user_state["firstMessage"])]
198
+ return "", chat_history
199
+
200
+
201
+ def regenerate_response(chat_history, user_state, model_tag):
202
+ last_row = chat_history.pop(-1)
203
+ chat_history.append((last_row[0], None))
204
+ model = model_mapping[model_tag]
205
+ conv = Conversation(user_state)
206
+ conv.set_chat_history(chat_history)
207
+ bot_message = model.generate_response(conv)
208
+ chat_history[-1] = (last_row[0], bot_message)
209
+ return "", chat_history
210
+
211
+
212
+ def reload_bot(bot_id, bot_profile, chat_history):
213
+ bot_config = download_bot_config(bot_id)
214
+ bot_profile = get_bot_profile(bot_config)
215
+ return bot_profile, [(None, bot_config["firstMessage"])], bot_config
216
+
217
+
218
+ send.click(respond, [msg, chatbot, user_state, model_tag], [msg, chatbot], queue=False)
219
+ msg.submit(respond, [msg, chatbot, user_state, model_tag], [msg, chatbot], queue=False)
220
+ clear.click(clear_chat, [chatbot, user_state], [msg, chatbot], queue=False)
221
+ regenerate.click(regenerate_response, [chatbot, user_state, model_tag], [msg, chatbot], queue=False)
222
+ reload_bot_button.click(reload_bot, [bot_id, bot_profile, chatbot], [bot_profile, chatbot, user_state],
223
+ queue=False)
224
 
225
  demo.launch(enable_queue=False)
models/chatml.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from conversation import Conversation
2
+ from models.base import BaseModel
3
+
4
+
5
+ class ChatML(BaseModel):
6
+
7
+ def _get_prompt(self, conversation: Conversation):
8
+ system_message = "\n".join(
9
+ [conversation.memory, conversation.prompt]
10
+ ).strip()
11
+ prompt = f"<|im_start|>system\n{system_message}<|im_end|>"
12
+
13
+ for message in conversation.messages:
14
+ prompt += f"\n<|im_start|>{message['from']}\n{message['value']}<|im_end|>"
15
+ prompt += f"\n<|im_start|>{conversation.bot_label}\n"
16
+ print(prompt)
17
+ return prompt