AlekseyKorshuk commited on
Commit
cb80e0b
·
1 Parent(s): fb9f0a9
app.py CHANGED
@@ -7,13 +7,29 @@ from conversation import Conversation
7
  from models.base import BaseModel
8
  import json
9
 
 
 
 
 
10
  from models.chatml import ChatML
 
 
 
 
 
11
 
 
 
 
 
12
  HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN")
13
  FIREBASE_URL = os.environ.get("FIREBASE_URL")
14
  CERTIFICATE = json.loads(os.environ.get("CERTIFICATE"))
15
  API_BASE_PATH = str(os.environ.get("API_BASE_PATH")).replace("\{\}", "{}")
16
 
 
 
 
17
  models = [
18
  BaseModel(
19
  name="mosaicml/mpt-7b",
@@ -78,7 +94,7 @@ models = [
78
  generation_params={
79
  'temperature': 0.72,
80
  'repetition_penalty': 1.0,
81
- 'max_new_tokens': 128,
82
  'top_k': 10,
83
  'top_p': 0.9,
84
  'do_sample': True,
@@ -158,68 +174,10 @@ def get_bot_profile(bot_config):
158
  with gr.Blocks() as demo:
159
  with gr.Tabs():
160
  with gr.TabItem("Playground"):
161
- default_bot_id = "_bot_e21de304-6151-4a04-b025-4c553ae8cbca"
162
- bot_config = download_bot_config(default_bot_id)
163
- user_state = gr.State(
164
- bot_config
165
- )
166
- with gr.Row():
167
- bot_id = gr.Textbox(label="Chai bot ID", value=default_bot_id, interactive=True)
168
- reload_bot_button = gr.Button("Reload bot")
169
-
170
- bot_profile = gr.HTML(get_bot_profile(bot_config))
171
-
172
- first_message = (None, bot_config["firstMessage"])
173
- chatbot = gr.Chatbot([first_message])
174
-
175
- msg = gr.Textbox(label="Message", value="Hi there!")
176
- with gr.Row():
177
- send = gr.Button("Send")
178
- regenerate = gr.Button("Regenerate")
179
- clear = gr.Button("Clear")
180
- values = list(model_mapping.keys())
181
- model_tag = gr.Dropdown(values, value=values[0], label="Model version")
182
-
183
-
184
- def respond(message, chat_history, user_state, model_tag):
185
- conv = Conversation(user_state)
186
- conv.set_chat_history(chat_history)
187
- conv.add_user_message(message)
188
- model = model_mapping[model_tag]
189
- bot_message = model.generate_response(conv)
190
- chat_history.append(
191
- (message, bot_message)
192
- )
193
- return "", chat_history
194
-
195
-
196
- def clear_chat(chat_history, user_state):
197
- chat_history = [(None, user_state["firstMessage"])]
198
- return "", chat_history
199
-
200
-
201
- def regenerate_response(chat_history, user_state, model_tag):
202
- last_row = chat_history.pop(-1)
203
- chat_history.append((last_row[0], None))
204
- model = model_mapping[model_tag]
205
- conv = Conversation(user_state)
206
- conv.set_chat_history(chat_history)
207
- bot_message = model.generate_response(conv)
208
- chat_history[-1] = (last_row[0], bot_message)
209
- return "", chat_history
210
-
211
-
212
- def reload_bot(bot_id, bot_profile, chat_history):
213
- bot_config = download_bot_config(bot_id)
214
- bot_profile = get_bot_profile(bot_config)
215
- return bot_profile, [(None, bot_config["firstMessage"])], bot_config
216
-
217
-
218
- send.click(respond, [msg, chatbot, user_state, model_tag], [msg, chatbot], queue=False)
219
- msg.submit(respond, [msg, chatbot, user_state, model_tag], [msg, chatbot], queue=False)
220
- clear.click(clear_chat, [chatbot, user_state], [msg, chatbot], queue=False)
221
- regenerate.click(regenerate_response, [chatbot, user_state, model_tag], [msg, chatbot], queue=False)
222
- reload_bot_button.click(reload_bot, [bot_id, bot_profile, chatbot], [bot_profile, chatbot, user_state],
223
- queue=False)
224
 
225
  demo.launch(enable_queue=False)
 
7
  from models.base import BaseModel
8
  import json
9
 
10
+ from tabs.arena_battle import get_tab_arena_battle
11
+ from tabs.arena_side_by_side import get_tab_arena_side_by_side
12
+ from tabs.playground import get_tab_playground
13
+
14
  from models.chatml import ChatML
15
+ import json
16
+ import os
17
+
18
+ import gspread
19
+ from oauth2client.service_account import ServiceAccountCredentials
20
 
21
+ scope = ["https://spreadsheets.google.com/feeds", 'https://www.googleapis.com/auth/spreadsheets',
22
+ "https://www.googleapis.com/auth/drive.file", "https://www.googleapis.com/auth/drive"]
23
+
24
+ GOOGLE_SHEETS_CERTIFICATE = json.loads(os.environ.get("GOOGLE_SHEETS_CERTIFICATE"))
25
  HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN")
26
  FIREBASE_URL = os.environ.get("FIREBASE_URL")
27
  CERTIFICATE = json.loads(os.environ.get("CERTIFICATE"))
28
  API_BASE_PATH = str(os.environ.get("API_BASE_PATH")).replace("\{\}", "{}")
29
 
30
+ creds = ServiceAccountCredentials.from_json_keyfile_dict(GOOGLE_SHEETS_CERTIFICATE, scope)
31
+ client = gspread.authorize(creds)
32
+
33
  models = [
34
  BaseModel(
35
  name="mosaicml/mpt-7b",
 
94
  generation_params={
95
  'temperature': 0.72,
96
  'repetition_penalty': 1.0,
97
+ 'max_new_tokens': 64,
98
  'top_k': 10,
99
  'top_p': 0.9,
100
  'do_sample': True,
 
174
  with gr.Blocks() as demo:
175
  with gr.Tabs():
176
  with gr.TabItem("Playground"):
177
+ get_tab_playground(download_bot_config, get_bot_profile, model_mapping)
178
+ with gr.TabItem("Chatbot Arena (battle)"):
179
+ get_tab_arena_battle(download_bot_config, get_bot_profile, model_mapping, client)
180
+ with gr.TabItem("Chatbot Arena (side-by-side)"):
181
+ get_tab_arena_side_by_side(download_bot_config, get_bot_profile, model_mapping, client)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
 
183
  demo.launch(enable_queue=False)
conversation.py CHANGED
@@ -1,4 +1,5 @@
1
  class Conversation:
 
2
  memory: str
3
  prompt: str
4
  bot_label: str
@@ -6,6 +7,7 @@ class Conversation:
6
  messages: list
7
 
8
  def __init__(self, bot_config):
 
9
  self.memory = bot_config.get("memory", "")
10
  self.prompt = bot_config.get("prompt", "")
11
  self.bot_label = bot_config.get("botLabel", "Character")
 
1
  class Conversation:
2
+ bot_id: str
3
  memory: str
4
  prompt: str
5
  bot_label: str
 
7
  messages: list
8
 
9
  def __init__(self, bot_config):
10
+ self.bot_id = bot_config.get("bot_id")
11
  self.memory = bot_config.get("memory", "")
12
  self.prompt = bot_config.get("prompt", "")
13
  self.bot_label = bot_config.get("botLabel", "Character")
models/base.py CHANGED
@@ -16,28 +16,28 @@ class BaseModel:
16
  self.namespace = namespace
17
  self.generation_params = generation_params
18
 
19
- def generate_response(self, conversation):
20
  prompt = self._get_prompt(conversation)
21
- response = self._get_response(prompt)
22
  return response
23
 
24
  def _get_prompt(self, conversation: Conversation):
25
- print(conversation.__dict__)
26
  prompt = "\n".join(
27
  [conversation.memory, conversation.prompt]
28
  ).strip()
29
-
30
  for message in conversation.messages:
31
  prompt += f"\n{message['from'].strip()}: {message['value'].strip()}"
32
  prompt += f"\n{conversation.bot_label}:"
33
- print(prompt)
34
  return prompt
35
 
36
- def _get_response(self, text):
37
  api = str(os.environ.get("API_BASE_PATH")).replace("\{\}", "{}")
38
  api = api.format(self.endpoint, self.namespace)
39
-
40
- payload = {'instances': [text], "parameters": self.generation_params}
 
 
 
41
  resp = requests.post(api, json=payload, timeout=600)
42
  assert resp.status_code == 200, (resp.content, resp.status_code)
43
  return resp.json()["predictions"][0].strip()
 
16
  self.namespace = namespace
17
  self.generation_params = generation_params
18
 
19
+ def generate_response(self, conversation, custom_generation_params=None):
20
  prompt = self._get_prompt(conversation)
21
+ response = self._get_response(prompt, custom_generation_params)
22
  return response
23
 
24
  def _get_prompt(self, conversation: Conversation):
 
25
  prompt = "\n".join(
26
  [conversation.memory, conversation.prompt]
27
  ).strip()
 
28
  for message in conversation.messages:
29
  prompt += f"\n{message['from'].strip()}: {message['value'].strip()}"
30
  prompt += f"\n{conversation.bot_label}:"
 
31
  return prompt
32
 
33
+ def _get_response(self, text, custom_generation_params):
34
  api = str(os.environ.get("API_BASE_PATH")).replace("\{\}", "{}")
35
  api = api.format(self.endpoint, self.namespace)
36
+ print(api)
37
+ parameters = self.generation_params
38
+ if custom_generation_params is not None:
39
+ parameters.update(custom_generation_params)
40
+ payload = {'instances': [text], "parameters": parameters}
41
  resp = requests.post(api, json=payload, timeout=600)
42
  assert resp.status_code == 200, (resp.content, resp.status_code)
43
  return resp.json()["predictions"][0].strip()
models/chatml.py CHANGED
@@ -9,9 +9,7 @@ class ChatML(BaseModel):
9
  [conversation.memory, conversation.prompt]
10
  ).strip()
11
  prompt = f"<|im_start|>system\n{system_message}<|im_end|>"
12
-
13
  for message in conversation.messages:
14
  prompt += f"\n<|im_start|>{message['from']}\n{message['value']}<|im_end|>"
15
  prompt += f"\n<|im_start|>{conversation.bot_label}\n"
16
- print(prompt)
17
  return prompt
 
9
  [conversation.memory, conversation.prompt]
10
  ).strip()
11
  prompt = f"<|im_start|>system\n{system_message}<|im_end|>"
 
12
  for message in conversation.messages:
13
  prompt += f"\n<|im_start|>{message['from']}\n{message['value']}<|im_end|>"
14
  prompt += f"\n<|im_start|>{conversation.bot_label}\n"
 
15
  return prompt
requirements.txt CHANGED
@@ -1,2 +1,4 @@
1
  firebase-admin==5.2.0
2
- gradio
 
 
 
1
  firebase-admin==5.2.0
2
+ gradio
3
+ gspread
4
+ oauth2client
tabs/arena_battle.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+
3
+ import gradio as gr
4
+ import random
5
+ from conversation import Conversation
6
+
7
+
8
+ def get_tab_arena_battle(download_bot_config, get_bot_profile, model_mapping, client):
9
+ default_bot_id = "_bot_e21de304-6151-4a04-b025-4c553ae8cbca"
10
+ bot_config = download_bot_config(default_bot_id)
11
+ user_state = gr.State(
12
+ bot_config
13
+ )
14
+ with gr.Row():
15
+ bot_id = gr.Textbox(label="Chai bot ID", value=default_bot_id, interactive=True)
16
+ reload_bot_button = gr.Button("Reload bot")
17
+ bot_profile = gr.HTML(get_bot_profile(bot_config))
18
+ with gr.Accordion("Bot config:", open=False):
19
+ gr.Markdown(f"# Memory\n{bot_config['memory']}\n# Prompt\n{bot_config['prompt']}\n")
20
+
21
+ with gr.Row():
22
+ values = list(model_mapping.keys())
23
+ first_message = (None, bot_config["firstMessage"])
24
+ height = 450
25
+ model_a_value, model_b_value = random.sample(values, 2)
26
+ with gr.Column():
27
+ model_a = gr.Textbox(value=model_a_value, label="Model A", interactive=False, visible=False)
28
+ chatbot_a = gr.Chatbot([first_message])
29
+ chatbot_a.style(height=height)
30
+ with gr.Column():
31
+ model_b = gr.Textbox(value=model_b_value, label="Model B", interactive=False, visible=False)
32
+ chatbot_b = gr.Chatbot([first_message])
33
+ chatbot_b.style(height=height)
34
+
35
+ with gr.Row():
36
+ with gr.Column(scale=3):
37
+ msg = gr.Textbox(show_label=False, value="Hi there!", interactive=True)
38
+ with gr.Column(scale=3):
39
+ send = gr.Button("Send")
40
+ with gr.Row():
41
+ vote_a = gr.Button("👈 A is better", interactive=False)
42
+ vote_b = gr.Button("👉 B is better", interactive=False)
43
+ vote_tie = gr.Button("🤝 Tie", interactive=False)
44
+ vote_bad = gr.Button("💩 Both are bad", interactive=False)
45
+ show_models_button = gr.Button("Show models", interactive=False)
46
+ with gr.Row():
47
+ regenerate = gr.Button("Regenerate", interactive=False)
48
+ clear = gr.Button("Restart")
49
+
50
+ with gr.Accordion("Generation parameters for model A", open=False):
51
+ model = model_mapping[model_a.value]
52
+ temperature_model_a = gr.Slider(minimum=0.0, maximum=1.0, value=model.generation_params["temperature"],
53
+ interactive=True, label="Temperature")
54
+ repetition_penalty_model_a = gr.Slider(minimum=0.0, maximum=2.0,
55
+ value=model.generation_params["repetition_penalty"],
56
+ interactive=True, label="Repetition penalty")
57
+ max_new_tokens_model_a = gr.Slider(minimum=1, maximum=512, value=model.generation_params["max_new_tokens"],
58
+ interactive=True, label="Max new tokens")
59
+ top_k_model_a = gr.Slider(minimum=1, maximum=100, value=model.generation_params["top_k"],
60
+ interactive=True, label="Top-K")
61
+ top_p_model_a = gr.Slider(minimum=0.0, maximum=1.0, value=model.generation_params["top_p"],
62
+ interactive=True, label="Top-P")
63
+
64
+ with gr.Accordion("Generation parameters for model B", open=False):
65
+ model = model_mapping[model_b.value]
66
+ temperature_model_b = gr.Slider(minimum=0.0, maximum=1.0, value=model.generation_params["temperature"],
67
+ interactive=True, label="Temperature")
68
+ repetition_penalty_model_b = gr.Slider(minimum=0.0, maximum=2.0,
69
+ value=model.generation_params["repetition_penalty"],
70
+ interactive=True, label="Repetition penalty")
71
+ max_new_tokens_model_b = gr.Slider(minimum=1, maximum=512, value=model.generation_params["max_new_tokens"],
72
+ interactive=True, label="Max new tokens")
73
+ top_k_model_b = gr.Slider(minimum=1, maximum=100, value=model.generation_params["top_k"],
74
+ interactive=True, label="Top-K")
75
+ top_p_model_b = gr.Slider(minimum=0.0, maximum=1.0, value=model.generation_params["top_p"],
76
+ interactive=True, label="Top-P")
77
+
78
+ def clear_chat(user_state):
79
+ return "", [(None, user_state["firstMessage"])], [(None, user_state["firstMessage"])]
80
+
81
+ def reload_bot(bot_id):
82
+ bot_config = download_bot_config(bot_id)
83
+ bot_profile = get_bot_profile(bot_config)
84
+ return bot_profile, [(None, bot_config["firstMessage"])], [(None, bot_config["firstMessage"])], bot_config
85
+
86
+ def get_generation_args(model_tag):
87
+ model = model_mapping[model_tag]
88
+ return (
89
+ model.generation_params["temperature"],
90
+ model.generation_params["repetition_penalty"],
91
+ model.generation_params["max_new_tokens"],
92
+ model.generation_params["top_k"],
93
+ model.generation_params["top_p"],
94
+ )
95
+
96
+ def respond(message, chat_history, user_state, model_tag,
97
+ temperature, repetition_penalty, max_new_tokens, top_k, top_p):
98
+ custom_generation_params = {
99
+ 'temperature': temperature,
100
+ 'repetition_penalty': repetition_penalty,
101
+ 'max_new_tokens': max_new_tokens,
102
+ 'top_k': top_k,
103
+ 'top_p': top_p,
104
+ }
105
+ conv = Conversation(user_state)
106
+ conv.set_chat_history(chat_history)
107
+ conv.add_user_message(message)
108
+ model = model_mapping[model_tag]
109
+ bot_message = model.generate_response(conv, custom_generation_params)
110
+ chat_history.append(
111
+ (message, bot_message)
112
+ )
113
+ return "", chat_history
114
+
115
+ def record_vote(user_state, vote,
116
+ chat_history_a, model_tag_a,
117
+ chat_history_b, model_tag_b):
118
+ conv_a = Conversation(user_state)
119
+ conv_a.set_chat_history(chat_history_a)
120
+ conv_b = Conversation(user_state)
121
+ conv_b.set_chat_history(chat_history_b)
122
+ if "A is better" in vote:
123
+ vote_str = "model_a"
124
+ elif "B is better" in vote:
125
+ vote_str = "model_b"
126
+ elif "Tie" in vote:
127
+ vote_str = "tie"
128
+ else:
129
+ vote_str = "tie (bothbad)"
130
+ row = {
131
+ "timestamp": time.time(),
132
+ "bot_id": user_state["bot_id"],
133
+ "vote": vote_str,
134
+ "model_a": model_tag_a,
135
+ "model_b": model_tag_b,
136
+ "is_anonymous": int(True)
137
+ }
138
+ sheet = client.open("Chat Arena").sheet1
139
+ num_rows = len(sheet.get_all_records())
140
+ sheet.insert_row(list(row.values()), index=num_rows + 2)
141
+ return gr.Button.update(interactive=True)
142
+
143
+ def regenerate_response(chat_history, user_state, model_tag,
144
+ temperature, repetition_penalty, max_new_tokens, top_k, top_p):
145
+ if len(chat_history) == 1:
146
+ return "", chat_history
147
+ custom_generation_params = {
148
+ 'temperature': temperature,
149
+ 'repetition_penalty': repetition_penalty,
150
+ 'max_new_tokens': max_new_tokens,
151
+ 'top_k': top_k,
152
+ 'top_p': top_p,
153
+ }
154
+ last_row = chat_history.pop(-1)
155
+ chat_history.append((last_row[0], None))
156
+ model = model_mapping[model_tag]
157
+ conv = Conversation(user_state)
158
+ conv.set_chat_history(chat_history)
159
+ bot_message = model.generate_response(conv, custom_generation_params)
160
+ chat_history[-1] = (last_row[0], bot_message)
161
+ return "", chat_history
162
+
163
+ def disable_voting():
164
+ return [gr.Button.update(interactive=False)] * 4
165
+
166
+ def enable_voting():
167
+ return [gr.Button.update(interactive=True)] * 4
168
+
169
+ def show_models():
170
+ return [gr.Textbox.update(visible=True)] * 2
171
+
172
+ def hide_models():
173
+ model_a_value, model_b_value = random.sample(values, 2)
174
+ return [gr.Textbox.update(visible=False, value=model_a_value),
175
+ gr.Textbox.update(visible=False, value=model_b_value)]
176
+
177
+ def disable_send():
178
+ return [gr.Button.update(interactive=False)] * 3
179
+
180
+ def enable_send():
181
+ return [gr.Button.update(interactive=True), gr.Button.update(interactive=False)]
182
+
183
+ def enable_regenerate():
184
+ return gr.Button.update(interactive=True)
185
+
186
+ for vote in [vote_a, vote_b, vote_tie, vote_bad]:
187
+ vote.click(record_vote,
188
+ [user_state, vote, chatbot_a, model_a, chatbot_b, model_b],
189
+ [show_models_button],
190
+ queue=False)
191
+ vote.click(disable_voting, None, [vote_a, vote_b, vote_tie, vote_bad], queue=False)
192
+
193
+ show_models_button.click(show_models, None, [model_a, model_b], queue=False)
194
+ clear.click(hide_models, None, [model_a, model_b], queue=False)
195
+ reload_bot_button.click(hide_models, None, [model_a, model_b], queue=False)
196
+ show_models_button.click(disable_voting, None, [vote_a, vote_b, vote_tie, vote_bad], queue=False)
197
+ show_models_button.click(disable_send, None, [send, regenerate, show_models_button], queue=False)
198
+ clear.click(enable_send, None, [send, regenerate], queue=False)
199
+ reload_bot_button.click(enable_send, None, [send, regenerate], queue=False)
200
+
201
+ model_a.change(get_generation_args, [model_a],
202
+ [temperature_model_a, repetition_penalty_model_a, max_new_tokens_model_a, top_k_model_a,
203
+ top_p_model_a], queue=False)
204
+ model_b.change(get_generation_args, [model_b],
205
+ [temperature_model_b, repetition_penalty_model_b, max_new_tokens_model_b, top_k_model_b,
206
+ top_p_model_b], queue=False)
207
+
208
+ clear.click(clear_chat, [user_state], [msg, chatbot_a, chatbot_b], queue=False)
209
+ model_a.change(clear_chat, [user_state], [msg, chatbot_a, chatbot_b], queue=False)
210
+ model_b.change(clear_chat, [user_state], [msg, chatbot_a, chatbot_b], queue=False)
211
+
212
+ # model_a.change(enable_voting, None, [vote_a, vote_b, vote_tie, vote_bad], queue=False)
213
+ # model_b.change(enable_voting, None, [vote_a, vote_b, vote_tie, vote_bad], queue=False)
214
+ reload_bot_button.click(disable_voting, None, [vote_a, vote_b, vote_tie, vote_bad], queue=False)
215
+ reload_bot_button.click(reload_bot, [bot_id], [bot_profile, chatbot_a, chatbot_b, user_state],
216
+ queue=False)
217
+ send.click(enable_voting, None, [vote_a, vote_b, vote_tie, vote_bad], queue=False)
218
+ clear.click(disable_voting, None, [vote_a, vote_b, vote_tie, vote_bad], queue=False)
219
+ regenerate.click(enable_voting, None, [vote_a, vote_b, vote_tie, vote_bad], queue=False)
220
+ msg.submit(enable_voting, None, [vote_a, vote_b, vote_tie, vote_bad], queue=False)
221
+
222
+ send.click(respond,
223
+ [msg, chatbot_a, user_state, model_a, temperature_model_a, repetition_penalty_model_a,
224
+ max_new_tokens_model_a, top_k_model_a, top_p_model_a], [msg, chatbot_a],
225
+ queue=False)
226
+ msg.submit(respond,
227
+ [msg, chatbot_a, user_state, model_a, temperature_model_a, repetition_penalty_model_a,
228
+ max_new_tokens_model_a, top_k_model_a, top_p_model_a], [msg, chatbot_a],
229
+ queue=False)
230
+
231
+ send.click(respond,
232
+ [msg, chatbot_b, user_state, model_b, temperature_model_b, repetition_penalty_model_b,
233
+ max_new_tokens_model_b, top_k_model_b, top_p_model_b], [msg, chatbot_b],
234
+ queue=False)
235
+ msg.submit(respond,
236
+ [msg, chatbot_b, user_state, model_b, temperature_model_b, repetition_penalty_model_b,
237
+ max_new_tokens_model_b, top_k_model_b, top_p_model_b], [msg, chatbot_b],
238
+ queue=False)
239
+
240
+ send.click(enable_regenerate, None, [regenerate], queue=False)
241
+ msg.submit(enable_regenerate, None, [regenerate], queue=False)
242
+
243
+ regenerate.click(regenerate_response,
244
+ [chatbot_a, user_state, model_a, temperature_model_a, repetition_penalty_model_a,
245
+ max_new_tokens_model_a, top_k_model_a,
246
+ top_p_model_a], [msg, chatbot_a], queue=False)
247
+ regenerate.click(regenerate_response,
248
+ [chatbot_b, user_state, model_b, temperature_model_b, repetition_penalty_model_b,
249
+ max_new_tokens_model_b, top_k_model_b,
250
+ top_p_model_b], [msg, chatbot_b], queue=False)
tabs/arena_side_by_side.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+
3
+ import gradio as gr
4
+ import random
5
+ from conversation import Conversation
6
+
7
+
8
+ def get_tab_arena_side_by_side(download_bot_config, get_bot_profile, model_mapping, client):
9
+ default_bot_id = "_bot_e21de304-6151-4a04-b025-4c553ae8cbca"
10
+ bot_config = download_bot_config(default_bot_id)
11
+ user_state = gr.State(
12
+ bot_config
13
+ )
14
+ with gr.Row():
15
+ bot_id = gr.Textbox(label="Chai bot ID", value=default_bot_id, interactive=True)
16
+ reload_bot_button = gr.Button("Reload bot")
17
+ bot_profile = gr.HTML(get_bot_profile(bot_config))
18
+ with gr.Accordion("Bot config:", open=False):
19
+ gr.Markdown(f"# Memory\n{bot_config['memory']}\n# Prompt\n{bot_config['prompt']}\n")
20
+
21
+ with gr.Row():
22
+ values = list(model_mapping.keys())
23
+ first_message = (None, bot_config["firstMessage"])
24
+ height = 450
25
+ model_a_value, model_b_value = random.sample(values, 2)
26
+ with gr.Column():
27
+ model_a = gr.Dropdown(values, value=model_a_value, label="Model A")
28
+ chatbot_a = gr.Chatbot([first_message])
29
+ chatbot_a.style(height=height)
30
+ with gr.Column():
31
+ model_b = gr.Dropdown(values, value=model_b_value, label="Model B")
32
+ chatbot_b = gr.Chatbot([first_message])
33
+ chatbot_b.style(height=height)
34
+
35
+ with gr.Row():
36
+ with gr.Column(scale=3):
37
+ msg = gr.Textbox(show_label=False, value="Hi there!", interactive=True)
38
+ with gr.Column(scale=3):
39
+ send = gr.Button("Send")
40
+ with gr.Row():
41
+ vote_a = gr.Button("👈 A is better", interactive=False)
42
+ vote_b = gr.Button("👉 B is better", interactive=False)
43
+ vote_tie = gr.Button("🤝 Tie", interactive=False)
44
+ vote_bad = gr.Button("💩 Both are bad", interactive=False)
45
+ with gr.Row():
46
+ regenerate = gr.Button("Regenerate", interactive=False)
47
+ clear = gr.Button("Clear")
48
+
49
+ with gr.Accordion("Generation parameters for model A", open=False):
50
+ model = model_mapping[model_a.value]
51
+ temperature_model_a = gr.Slider(minimum=0.0, maximum=1.0, value=model.generation_params["temperature"],
52
+ interactive=True, label="Temperature")
53
+ repetition_penalty_model_a = gr.Slider(minimum=0.0, maximum=2.0,
54
+ value=model.generation_params["repetition_penalty"],
55
+ interactive=True, label="Repetition penalty")
56
+ max_new_tokens_model_a = gr.Slider(minimum=1, maximum=512, value=model.generation_params["max_new_tokens"],
57
+ interactive=True, label="Max new tokens")
58
+ top_k_model_a = gr.Slider(minimum=1, maximum=100, value=model.generation_params["top_k"],
59
+ interactive=True, label="Top-K")
60
+ top_p_model_a = gr.Slider(minimum=0.0, maximum=1.0, value=model.generation_params["top_p"],
61
+ interactive=True, label="Top-P")
62
+
63
+ with gr.Accordion("Generation parameters for model B", open=False):
64
+ model = model_mapping[model_b.value]
65
+ temperature_model_b = gr.Slider(minimum=0.0, maximum=1.0, value=model.generation_params["temperature"],
66
+ interactive=True, label="Temperature")
67
+ repetition_penalty_model_b = gr.Slider(minimum=0.0, maximum=2.0,
68
+ value=model.generation_params["repetition_penalty"],
69
+ interactive=True, label="Repetition penalty")
70
+ max_new_tokens_model_b = gr.Slider(minimum=1, maximum=512, value=model.generation_params["max_new_tokens"],
71
+ interactive=True, label="Max new tokens")
72
+ top_k_model_b = gr.Slider(minimum=1, maximum=100, value=model.generation_params["top_k"],
73
+ interactive=True, label="Top-K")
74
+ top_p_model_b = gr.Slider(minimum=0.0, maximum=1.0, value=model.generation_params["top_p"],
75
+ interactive=True, label="Top-P")
76
+
77
+ def clear_chat(user_state):
78
+ return "", [(None, user_state["firstMessage"])], [(None, user_state["firstMessage"])]
79
+
80
+ def reload_bot(bot_id):
81
+ bot_config = download_bot_config(bot_id)
82
+ bot_profile = get_bot_profile(bot_config)
83
+ return bot_profile, [(None, bot_config["firstMessage"])], [(None, bot_config["firstMessage"])], bot_config
84
+
85
+ def get_generation_args(model_tag):
86
+ model = model_mapping[model_tag]
87
+ return (
88
+ model.generation_params["temperature"],
89
+ model.generation_params["repetition_penalty"],
90
+ model.generation_params["max_new_tokens"],
91
+ model.generation_params["top_k"],
92
+ model.generation_params["top_p"],
93
+ )
94
+
95
+ def respond(message, chat_history, user_state, model_tag,
96
+ temperature, repetition_penalty, max_new_tokens, top_k, top_p):
97
+ custom_generation_params = {
98
+ 'temperature': temperature,
99
+ 'repetition_penalty': repetition_penalty,
100
+ 'max_new_tokens': max_new_tokens,
101
+ 'top_k': top_k,
102
+ 'top_p': top_p,
103
+ }
104
+ conv = Conversation(user_state)
105
+ conv.set_chat_history(chat_history)
106
+ conv.add_user_message(message)
107
+ model = model_mapping[model_tag]
108
+ bot_message = model.generate_response(conv, custom_generation_params)
109
+ chat_history.append(
110
+ (message, bot_message)
111
+ )
112
+ return "", chat_history
113
+
114
+ def record_vote(user_state, vote,
115
+ chat_history_a, model_tag_a,
116
+ chat_history_b, model_tag_b):
117
+ if len(chat_history_a) < 2:
118
+ return
119
+ conv_a = Conversation(user_state)
120
+ conv_a.set_chat_history(chat_history_a)
121
+ conv_b = Conversation(user_state)
122
+ conv_b.set_chat_history(chat_history_b)
123
+ if "A is better" in vote:
124
+ vote_str = "model_a"
125
+ elif "B is better" in vote:
126
+ vote_str = "model_b"
127
+ elif "Tie" in vote:
128
+ vote_str = "tie"
129
+ else:
130
+ vote_str = "tie (bothbad)"
131
+ row = {
132
+ "timestamp": time.time(),
133
+ "bot_id": user_state["bot_id"],
134
+ "vote": vote_str,
135
+ "model_a": model_tag_a,
136
+ "model_b": model_tag_b,
137
+ "is_anonymous": int(False)
138
+ }
139
+ sheet = client.open("Chat Arena").sheet1
140
+ num_rows = len(sheet.get_all_records())
141
+ sheet.insert_row(list(row.values()), index=num_rows + 2)
142
+ return
143
+
144
+ def regenerate_response(chat_history, user_state, model_tag,
145
+ temperature, repetition_penalty, max_new_tokens, top_k, top_p):
146
+ custom_generation_params = {
147
+ 'temperature': temperature,
148
+ 'repetition_penalty': repetition_penalty,
149
+ 'max_new_tokens': max_new_tokens,
150
+ 'top_k': top_k,
151
+ 'top_p': top_p,
152
+ }
153
+ last_row = chat_history.pop(-1)
154
+ chat_history.append((last_row[0], None))
155
+ model = model_mapping[model_tag]
156
+ conv = Conversation(user_state)
157
+ conv.set_chat_history(chat_history)
158
+ bot_message = model.generate_response(conv, custom_generation_params)
159
+ chat_history[-1] = (last_row[0], bot_message)
160
+ return "", chat_history
161
+
162
+ def disable_voting():
163
+ return [gr.Button.update(interactive=False)] * 4
164
+
165
+ def enable_voting():
166
+ return [gr.Button.update(interactive=True)] * 4
167
+
168
+ def enable_send():
169
+ return [gr.Button.update(interactive=True), gr.Button.update(interactive=False)]
170
+
171
+ def enable_regenerate():
172
+ return gr.Button.update(interactive=True)
173
+
174
+ for vote in [vote_a, vote_b, vote_tie, vote_bad]:
175
+ vote.click(record_vote,
176
+ [user_state, vote, chatbot_a, model_a, chatbot_b, model_b],
177
+ None,
178
+ queue=False)
179
+ vote.click(disable_voting, None, [vote_a, vote_b, vote_tie, vote_bad], queue=False)
180
+
181
+ model_a.change(get_generation_args, [model_a],
182
+ [temperature_model_a, repetition_penalty_model_a, max_new_tokens_model_a, top_k_model_a,
183
+ top_p_model_a], queue=False)
184
+ model_b.change(get_generation_args, [model_b],
185
+ [temperature_model_b, repetition_penalty_model_b, max_new_tokens_model_b, top_k_model_b,
186
+ top_p_model_b], queue=False)
187
+ reload_bot_button.click(reload_bot, [bot_id], [bot_profile, chatbot_a, chatbot_b, user_state],
188
+ queue=False)
189
+ clear.click(clear_chat, [user_state], [msg, chatbot_a, chatbot_b], queue=False)
190
+ model_a.change(clear_chat, [user_state], [msg, chatbot_a, chatbot_b], queue=False)
191
+ model_b.change(clear_chat, [user_state], [msg, chatbot_a, chatbot_b], queue=False)
192
+ clear.click(enable_send, None, [send, regenerate], queue=False)
193
+ reload_bot_button.click(enable_send, None, [send, regenerate], queue=False)
194
+
195
+ model_a.change(enable_voting, None, [vote_a, vote_b, vote_tie, vote_bad], queue=False)
196
+ model_b.change(enable_voting, None, [vote_a, vote_b, vote_tie, vote_bad], queue=False)
197
+ reload_bot_button.click(disable_voting, None, [vote_a, vote_b, vote_tie, vote_bad], queue=False)
198
+ send.click(enable_voting, None, [vote_a, vote_b, vote_tie, vote_bad], queue=False)
199
+ clear.click(disable_voting, None, [vote_a, vote_b, vote_tie, vote_bad], queue=False)
200
+ regenerate.click(enable_voting, None, [vote_a, vote_b, vote_tie, vote_bad], queue=False)
201
+ msg.submit(enable_voting, None, [vote_a, vote_b, vote_tie, vote_bad], queue=False)
202
+
203
+ send.click(respond,
204
+ [msg, chatbot_a, user_state, model_a, temperature_model_a, repetition_penalty_model_a,
205
+ max_new_tokens_model_a, top_k_model_a, top_p_model_a], [msg, chatbot_a],
206
+ queue=False)
207
+ msg.submit(respond,
208
+ [msg, chatbot_a, user_state, model_a, temperature_model_a, repetition_penalty_model_a,
209
+ max_new_tokens_model_a, top_k_model_a, top_p_model_a], [msg, chatbot_a],
210
+ queue=False)
211
+
212
+ send.click(respond,
213
+ [msg, chatbot_b, user_state, model_b, temperature_model_b, repetition_penalty_model_b,
214
+ max_new_tokens_model_b, top_k_model_b, top_p_model_b], [msg, chatbot_b],
215
+ queue=False)
216
+ msg.submit(respond,
217
+ [msg, chatbot_b, user_state, model_b, temperature_model_b, repetition_penalty_model_b,
218
+ max_new_tokens_model_b, top_k_model_b, top_p_model_b], [msg, chatbot_b],
219
+ queue=False)
220
+
221
+ send.click(enable_regenerate, None, [regenerate], queue=False)
222
+ msg.submit(enable_regenerate, None, [regenerate], queue=False)
223
+
224
+ regenerate.click(regenerate_response,
225
+ [chatbot_a, user_state, model_a, temperature_model_a, repetition_penalty_model_a,
226
+ max_new_tokens_model_a, top_k_model_a,
227
+ top_p_model_a], [msg, chatbot_a], queue=False)
228
+ regenerate.click(regenerate_response,
229
+ [chatbot_b, user_state, model_b, temperature_model_b, repetition_penalty_model_b,
230
+ max_new_tokens_model_b, top_k_model_b,
231
+ top_p_model_b], [msg, chatbot_b], queue=False)
tabs/playground.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from conversation import Conversation
3
+
4
+
5
+ def get_tab_playground(download_bot_config, get_bot_profile, model_mapping):
6
+ default_bot_id = "_bot_e21de304-6151-4a04-b025-4c553ae8cbca"
7
+ bot_config = download_bot_config(default_bot_id)
8
+ user_state = gr.State(
9
+ bot_config
10
+ )
11
+ with gr.Row():
12
+ bot_id = gr.Textbox(label="Chai bot ID", value=default_bot_id, interactive=True)
13
+ reload_bot_button = gr.Button("Reload bot")
14
+
15
+ bot_profile = gr.HTML(get_bot_profile(bot_config))
16
+ with gr.Accordion("Bot config:", open=False):
17
+ gr.Markdown(f"# Memory\n{bot_config['memory']}\n# Prompt\n{bot_config['prompt']}\n")
18
+
19
+ first_message = (None, bot_config["firstMessage"])
20
+ chatbot = gr.Chatbot([first_message])
21
+
22
+ msg = gr.Textbox(label="Message", value="Hi there!")
23
+ with gr.Row():
24
+ send = gr.Button("Send")
25
+ regenerate = gr.Button("Regenerate")
26
+ clear = gr.Button("Clear")
27
+ values = list(model_mapping.keys())
28
+ model_tag = gr.Dropdown(values, value=values[0], label="Model version")
29
+ model = model_mapping[model_tag.value]
30
+
31
+ with gr.Accordion("Generation parameters", open=False):
32
+ temperature = gr.Slider(minimum=0.0, maximum=1.0, value=model.generation_params["temperature"],
33
+ interactive=True, label="Temperature")
34
+ repetition_penalty = gr.Slider(minimum=0.0, maximum=2.0,
35
+ value=model.generation_params["repetition_penalty"],
36
+ interactive=True, label="Repetition penalty")
37
+ max_new_tokens = gr.Slider(minimum=1, maximum=512, value=model.generation_params["max_new_tokens"],
38
+ interactive=True, label="Max new tokens")
39
+ top_k = gr.Slider(minimum=1, maximum=100, value=model.generation_params["top_k"],
40
+ interactive=True, label="Top-K")
41
+ top_p = gr.Slider(minimum=0.0, maximum=1.0, value=model.generation_params["top_p"],
42
+ interactive=True, label="Top-P")
43
+
44
+ def respond(message, chat_history, user_state, model_tag,
45
+ temperature, repetition_penalty, max_new_tokens, top_k, top_p):
46
+ custom_generation_params = {
47
+ 'temperature': temperature,
48
+ 'repetition_penalty': repetition_penalty,
49
+ 'max_new_tokens': max_new_tokens,
50
+ 'top_k': top_k,
51
+ 'top_p': top_p,
52
+ }
53
+ conv = Conversation(user_state)
54
+ conv.set_chat_history(chat_history)
55
+ conv.add_user_message(message)
56
+ model = model_mapping[model_tag]
57
+ bot_message = model.generate_response(conv, custom_generation_params)
58
+ chat_history.append(
59
+ (message, bot_message)
60
+ )
61
+ return "", chat_history
62
+
63
+ def clear_chat(chat_history, user_state):
64
+ chat_history = [(None, user_state["firstMessage"])]
65
+ return chat_history
66
+
67
+ def regenerate_response(chat_history, user_state, model_tag,
68
+ temperature, repetition_penalty, max_new_tokens, top_k, top_p):
69
+ custom_generation_params = {
70
+ 'temperature': temperature,
71
+ 'repetition_penalty': repetition_penalty,
72
+ 'max_new_tokens': max_new_tokens,
73
+ 'top_k': top_k,
74
+ 'top_p': top_p,
75
+ }
76
+ last_row = chat_history.pop(-1)
77
+ chat_history.append((last_row[0], None))
78
+ model = model_mapping[model_tag]
79
+ conv = Conversation(user_state)
80
+ conv.set_chat_history(chat_history)
81
+ bot_message = model.generate_response(conv, custom_generation_params)
82
+ chat_history[-1] = (last_row[0], bot_message)
83
+ return chat_history
84
+
85
+ def reload_bot(bot_id, bot_profile, chat_history):
86
+ bot_config = download_bot_config(bot_id)
87
+ bot_profile = get_bot_profile(bot_config)
88
+ return bot_profile, [(None, bot_config["firstMessage"])], bot_config
89
+
90
+ def get_generation_args(model_tag):
91
+ model = model_mapping[model_tag]
92
+ return (
93
+ model.generation_params["temperature"],
94
+ model.generation_params["repetition_penalty"],
95
+ model.generation_params["max_new_tokens"],
96
+ model.generation_params["top_k"],
97
+ model.generation_params["top_p"],
98
+ )
99
+
100
+ model_tag.change(get_generation_args, [model_tag], [temperature, repetition_penalty, max_new_tokens, top_k,
101
+ top_p], queue=False)
102
+ send.click(respond,
103
+ [msg, chatbot, user_state, model_tag, temperature, repetition_penalty, max_new_tokens, top_k,
104
+ top_p], [msg, chatbot],
105
+ queue=False)
106
+ msg.submit(respond,
107
+ [msg, chatbot, user_state, model_tag, temperature, repetition_penalty, max_new_tokens, top_k,
108
+ top_p], [msg, chatbot],
109
+ queue=False)
110
+ clear.click(clear_chat, [chatbot, user_state], [chatbot], queue=False)
111
+ regenerate.click(regenerate_response,
112
+ [chatbot, user_state, model_tag, temperature, repetition_penalty, max_new_tokens, top_k,
113
+ top_p], [chatbot], queue=False)
114
+ reload_bot_button.click(reload_bot, [bot_id, bot_profile, chatbot], [bot_profile, chatbot, user_state],
115
+ queue=False)