Spaces:
Runtime error
Runtime error
Leaderboard and Unified UI (#61)
Browse files* Update start_app.sh to use gradio instead of python app.py
* fixed action typing error
---------
Co-authored-by: Jasonqi146 <jasonqi146@gmail.com>
- README.md +13 -0
- app.py +65 -314
- data_dir/models_vs_gpt35.jsonl +4 -0
- requirements.txt +14 -14
- sotopia_space/_header.md +4 -0
- sotopia_space/benchmark.py +70 -0
- sotopia_space/chat.py +284 -0
- sotopia_space/constants.py +39 -0
- sotopia_space/utils.py +223 -0
- start_app.sh +1 -1
- ui_constants.py +191 -0
README.md
CHANGED
@@ -11,3 +11,16 @@ license: apache-2.0
|
|
11 |
---
|
12 |
|
13 |
This is a synced repository with a Huggingface Space for the Sotopia project [space](https://huggingface.co/spaces/wdplx/Sotopia-demo)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
---
|
12 |
|
13 |
This is a synced repository with a Huggingface Space for the Sotopia project [space](https://huggingface.co/spaces/wdplx/Sotopia-demo)
|
14 |
+
|
15 |
+
## Getting Started
|
16 |
+
|
17 |
+
```bash
|
18 |
+
conda create -n sotopia-space python=3.11; conda activate sotopia-space
|
19 |
+
python -m pip install -r requirements.txt
|
20 |
+
```
|
21 |
+
|
22 |
+
To run the app, run the following command:
|
23 |
+
|
24 |
+
```bash
|
25 |
+
bash start_app.sh
|
26 |
+
```
|
app.py
CHANGED
@@ -1,332 +1,83 @@
|
|
1 |
import os
|
2 |
-
|
3 |
-
import json
|
4 |
from typing import Literal
|
5 |
|
6 |
-
import gradio as gr
|
|
|
|
|
|
|
7 |
|
8 |
-
from utils import Environment, Agent, get_context_prompt, dialogue_history_prompt
|
9 |
-
from functools import cache
|
10 |
-
from sotopia_pi_generate import prepare_model, generate_action
|
11 |
|
12 |
OPENAI_KEY_FILE="./openai_api.key"
|
13 |
if os.path.exists(OPENAI_KEY_FILE):
|
14 |
with open(OPENAI_KEY_FILE, "r") as f:
|
15 |
os.environ["OPENAI_API_KEY"] = f.read().strip()
|
16 |
|
17 |
-
|
18 |
-
|
19 |
-
TEMPERATURE = 0.7
|
20 |
-
TOP_P = 1
|
21 |
-
MAX_TOKENS = 1024
|
22 |
|
23 |
-
|
24 |
-
AGENT_PROFILES = "profiles/agent_profiles.jsonl"
|
25 |
-
RELATIONSHIP_PROFILES = "profiles/relationship_profiles.jsonl"
|
26 |
-
|
27 |
-
ACTION_TYPES = ['none', 'action', 'non-verbal communication', 'speak', 'leave']
|
28 |
-
|
29 |
-
MODEL_OPTIONS = [
|
30 |
-
"gpt-3.5-turbo",
|
31 |
-
"gpt-4",
|
32 |
-
"gpt-4-turbo",
|
33 |
-
"cmu-lti/sotopia-pi-mistral-7b-BC_SR",
|
34 |
-
"cmu-lti/sotopia-pi-mistral-7b-BC_SR_4bit",
|
35 |
-
"mistralai/Mistral-7B-Instruct-v0.1"
|
36 |
-
# "mistralai/Mixtral-8x7B-Instruct-v0.1",
|
37 |
-
# "togethercomputer/llama-2-7b-chat",
|
38 |
-
# "togethercomputer/llama-2-70b-chat",
|
39 |
-
# "togethercomputer/mpt-30b-chat",
|
40 |
-
# "together_ai/togethercomputer/llama-2-7b-chat",
|
41 |
-
# "together_ai/togethercomputer/falcon-7b-instruct",
|
42 |
-
]
|
43 |
-
|
44 |
-
@cache
|
45 |
-
def get_sotopia_profiles(env_file=ENVIRONMENT_PROFILES, agent_file=AGENT_PROFILES, relationship_file=RELATIONSHIP_PROFILES):
|
46 |
-
with open(env_file, 'r') as f:
|
47 |
-
data = [json.loads(line) for line in f.readlines()]
|
48 |
-
|
49 |
-
code_names_count = defaultdict(int)
|
50 |
-
environments = []
|
51 |
-
environment_dict = {}
|
52 |
-
for profile in sorted(data, key=lambda x: x['codename']):
|
53 |
-
env_obj = Environment(profile)
|
54 |
-
if profile['codename'] in code_names_count:
|
55 |
-
environments.append((
|
56 |
-
"{}_{:05d}".format(profile['codename'],
|
57 |
-
code_names_count[profile['codename']]
|
58 |
-
),
|
59 |
-
env_obj._id
|
60 |
-
))
|
61 |
-
else:
|
62 |
-
environments.append((profile['codename'], env_obj._id))
|
63 |
-
environment_dict[env_obj._id] = env_obj
|
64 |
-
code_names_count[profile['codename']] += 1
|
65 |
-
|
66 |
-
with open(agent_file, 'r') as f:
|
67 |
-
data = [json.loads(line) for line in f.readlines()]
|
68 |
-
|
69 |
-
agent_dict = {}
|
70 |
-
for profile in data:
|
71 |
-
agent_obj = Agent(profile)
|
72 |
-
agent_dict[agent_obj._id] = agent_obj
|
73 |
-
|
74 |
-
with open(relationship_file, 'r') as f:
|
75 |
-
data = [json.loads(line) for line in f.readlines()]
|
76 |
-
|
77 |
-
relationship_dict = defaultdict(lambda : defaultdict(list))
|
78 |
-
for profile in data:
|
79 |
-
relationship_dict[profile['relationship']][profile['agent1_id']].append(profile['agent2_id'])
|
80 |
-
relationship_dict[profile['relationship']][profile['agent2_id']].append(profile['agent1_id'])
|
81 |
-
|
82 |
-
return environments, environment_dict, agent_dict, relationship_dict
|
83 |
-
|
84 |
-
|
85 |
-
def introduction():
|
86 |
with gr.Column(scale=2):
|
87 |
-
gr.
|
88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
)
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
environment = environment_dict[environment_id]
|
116 |
-
|
117 |
-
user_agents_list = []
|
118 |
-
unique_agent_ids = set()
|
119 |
-
for x, _ in relationship_dict[environment.relationship].items():
|
120 |
-
unique_agent_ids.add(x)
|
121 |
-
|
122 |
-
for agent_id in unique_agent_ids:
|
123 |
-
user_agents_list.append((agent_dict[agent_id].name, agent_id))
|
124 |
-
return gr.Dropdown(choices=user_agents_list, value=user_agents_list[0][1] if user_agents_list else None, label="User Agent Selection")
|
125 |
-
|
126 |
-
def create_bot_agent_dropdown(environment_id, user_agent_id):
|
127 |
-
_, environment_dict, agent_dict, relationship_dict = get_sotopia_profiles()
|
128 |
-
environment, user_agent = environment_dict[environment_id], agent_dict[user_agent_id]
|
129 |
-
|
130 |
-
bot_agent_list = []
|
131 |
-
for neighbor_id in relationship_dict[environment.relationship][user_agent.agent_id]:
|
132 |
-
bot_agent_list.append((agent_dict[neighbor_id].name, neighbor_id))
|
133 |
-
|
134 |
-
return gr.Dropdown(choices=bot_agent_list, value=bot_agent_list[0][1] if bot_agent_list else None, label="Bot Agent Selection")
|
135 |
-
|
136 |
-
def create_environment_info(environment_dropdown):
|
137 |
-
_, environment_dict, _, _ = get_sotopia_profiles()
|
138 |
-
environment = environment_dict[environment_dropdown]
|
139 |
-
text = environment.scenario
|
140 |
-
return gr.Textbox(label="Scenario", lines=1, value=text)
|
141 |
-
|
142 |
-
def create_user_info(user_agent_dropdown):
|
143 |
-
_, _, agent_dict, _ = get_sotopia_profiles()
|
144 |
-
user_agent = agent_dict[user_agent_dropdown]
|
145 |
-
text = f"{user_agent.background} {user_agent.personality}"
|
146 |
-
return gr.Textbox(label="User Agent Profile", lines=4, value=text)
|
147 |
-
|
148 |
-
def create_bot_info(bot_agent_dropdown):
|
149 |
-
_, _, agent_dict, _ = get_sotopia_profiles()
|
150 |
-
bot_agent = agent_dict[bot_agent_dropdown]
|
151 |
-
text = f"{bot_agent.background} {bot_agent.personality}"
|
152 |
-
return gr.Textbox(label="Bot Agent Profile", lines=4, value=text)
|
153 |
-
|
154 |
-
def create_user_goal(environment_dropdown):
|
155 |
-
_, environment_dict, _, _ = get_sotopia_profiles()
|
156 |
-
text = environment_dict[environment_dropdown].agent_goals[0]
|
157 |
-
text = text.replace('(', '').replace(')', '')
|
158 |
-
if "<extra_info>" in text:
|
159 |
-
text = text.replace("<extra_info>", "\n\n")
|
160 |
-
text = text.replace("</extra_info>", "\n")
|
161 |
-
if "<strategy_hint>" in text:
|
162 |
-
text = text.replace("<strategy_hint>", "\n\n")
|
163 |
-
text = text.replace("</strategy_hint>", "\n")
|
164 |
-
return gr.Textbox(label="User Agent Goal", lines=4, value=text)
|
165 |
-
|
166 |
-
def create_bot_goal(environment_dropdown):
|
167 |
-
_, environment_dict, _, _ = get_sotopia_profiles()
|
168 |
-
text = environment_dict[environment_dropdown].agent_goals[1]
|
169 |
-
text = text.replace('(', '').replace(')', '')
|
170 |
-
if "<extra_info>" in text:
|
171 |
-
text = text.replace("<extra_info>", "\n\n")
|
172 |
-
text = text.replace("</extra_info>", "\n")
|
173 |
-
if "<strategy_hint>" in text:
|
174 |
-
text = text.replace("<strategy_hint>", "\n\n")
|
175 |
-
text = text.replace("</strategy_hint>", "\n")
|
176 |
-
return gr.Textbox(label="Bot Agent Goal", lines=4, value=text)
|
177 |
-
|
178 |
-
def sotopia_info_accordion(accordion_visible=True):
|
179 |
-
environments, _, _, _ = get_sotopia_profiles()
|
180 |
-
|
181 |
-
with gr.Accordion("Create your sotopia space!", open=accordion_visible):
|
182 |
-
with gr.Row():
|
183 |
-
environment_dropdown = gr.Dropdown(
|
184 |
-
choices=environments,
|
185 |
-
label="Scenario Selection",
|
186 |
-
value=environments[0][1] if environments else None,
|
187 |
-
interactive=True,
|
188 |
-
)
|
189 |
-
model_name_dropdown = gr.Dropdown(
|
190 |
-
choices=MODEL_OPTIONS,
|
191 |
-
value=DEFAULT_MODEL_SELECTION,
|
192 |
-
interactive=True,
|
193 |
-
label="Model Selection"
|
194 |
-
)
|
195 |
-
|
196 |
-
with gr.Row():
|
197 |
-
user_agent_dropdown = create_user_agent_dropdown(environment_dropdown.value)
|
198 |
-
bot_agent_dropdown = create_bot_agent_dropdown(environment_dropdown.value, user_agent_dropdown.value)
|
199 |
-
|
200 |
-
with gr.Accordion("Check your social task!", open=accordion_visible):
|
201 |
-
|
202 |
-
scenario_info_display = create_environment_info(environment_dropdown.value)
|
203 |
-
|
204 |
-
with gr.Row():
|
205 |
-
bot_goal_display = create_bot_goal(environment_dropdown.value)
|
206 |
-
user_goal_display = create_user_goal(environment_dropdown.value)
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
with gr.Row():
|
211 |
-
bot_agent_info_display = create_bot_info(bot_agent_dropdown.value)
|
212 |
-
user_agent_info_display = create_user_info(user_agent_dropdown.value)
|
213 |
-
|
214 |
-
# Update user dropdown when scenario changes
|
215 |
-
environment_dropdown.change(fn=create_user_agent_dropdown, inputs=[environment_dropdown], outputs=[user_agent_dropdown])
|
216 |
-
# Update bot dropdown when user or scenario changes
|
217 |
-
user_agent_dropdown.change(fn=create_bot_agent_dropdown, inputs=[environment_dropdown, user_agent_dropdown], outputs=[bot_agent_dropdown])
|
218 |
-
# Update scenario information when scenario changes
|
219 |
-
environment_dropdown.change(fn=create_environment_info, inputs=[environment_dropdown], outputs=[scenario_info_display])
|
220 |
-
# Update user agent profile when user changes
|
221 |
-
user_agent_dropdown.change(fn=create_user_info, inputs=[user_agent_dropdown], outputs=[user_agent_info_display])
|
222 |
-
# Update bot agent profile when bot changes
|
223 |
-
bot_agent_dropdown.change(fn=create_bot_info, inputs=[bot_agent_dropdown], outputs=[bot_agent_info_display])
|
224 |
-
# Update user goal when scenario changes
|
225 |
-
environment_dropdown.change(fn=create_user_goal, inputs=[environment_dropdown], outputs=[user_goal_display])
|
226 |
-
# Update bot goal when scenario changes
|
227 |
-
environment_dropdown.change(fn=create_bot_goal, inputs=[environment_dropdown], outputs=[bot_goal_display])
|
228 |
-
|
229 |
-
return model_name_dropdown, environment_dropdown, user_agent_dropdown, bot_agent_dropdown
|
230 |
-
|
231 |
-
def instructions_accordion(instructions, according_visible=False):
|
232 |
-
with gr.Accordion("Instructions", open=False, visible=according_visible):
|
233 |
-
instructions = gr.Textbox(
|
234 |
-
lines=10,
|
235 |
-
value=instructions,
|
236 |
-
interactive=False,
|
237 |
-
placeholder="Instructions",
|
238 |
-
show_label=False,
|
239 |
-
max_lines=10,
|
240 |
-
visible=False,
|
241 |
-
)
|
242 |
-
return instructions
|
243 |
-
|
244 |
-
|
245 |
-
def chat_tab():
|
246 |
-
# history are input output pairs
|
247 |
-
_, environment_dict, agent_dict, _ = get_sotopia_profiles()
|
248 |
-
def run_chat(
|
249 |
-
message,
|
250 |
-
history,
|
251 |
-
environment_selection,
|
252 |
-
user_agent_dropdown,
|
253 |
-
bot_agent_dropdown,
|
254 |
-
model_selection:str
|
255 |
-
):
|
256 |
-
environment = environment_dict[environment_selection]
|
257 |
-
user_agent = agent_dict[user_agent_dropdown]
|
258 |
-
bot_agent = agent_dict[bot_agent_dropdown]
|
259 |
-
|
260 |
-
context = get_context_prompt(bot_agent, user_agent, environment)
|
261 |
-
dialogue_history, next_turn_idx = dialogue_history_prompt(message, history, user_agent, bot_agent)
|
262 |
-
prompt_history = f"{context}{dialogue_history}"
|
263 |
-
agent_action = generate_action(model_selection, prompt_history, next_turn_idx, ACTION_TYPES, bot_agent.name, TEMPERATURE)
|
264 |
-
return agent_action.to_natural_language()
|
265 |
-
|
266 |
-
with gr.Column():
|
267 |
-
with gr.Blocks():
|
268 |
-
model_name_dropdown, scenario_dropdown, user_agent_dropdown, bot_agent_dropdown = sotopia_info_accordion()
|
269 |
-
|
270 |
-
with gr.Column():
|
271 |
-
with gr.Accordion("Start the conversation to achieve your goal!", open=True):
|
272 |
-
gr.ChatInterface(
|
273 |
-
fn=run_chat,
|
274 |
-
chatbot=gr.Chatbot(
|
275 |
-
height=620,
|
276 |
-
render=False,
|
277 |
-
show_label=False,
|
278 |
-
rtl=False,
|
279 |
-
avatar_images=(
|
280 |
-
"images/profile1.jpg",
|
281 |
-
"images/profile2.jpg",
|
282 |
-
),
|
283 |
-
),
|
284 |
-
textbox=gr.Textbox(
|
285 |
-
placeholder="Write your message here...",
|
286 |
-
render=False,
|
287 |
-
scale=7,
|
288 |
-
rtl=False,
|
289 |
-
),
|
290 |
-
additional_inputs=[
|
291 |
-
scenario_dropdown,
|
292 |
-
user_agent_dropdown,
|
293 |
-
bot_agent_dropdown,
|
294 |
-
model_name_dropdown,
|
295 |
-
],
|
296 |
-
submit_btn="Send",
|
297 |
-
stop_btn="Stop",
|
298 |
-
retry_btn="🔄 Retry",
|
299 |
-
undo_btn="↩️ Delete",
|
300 |
-
clear_btn="🗑️ Clear",
|
301 |
-
)
|
302 |
-
|
303 |
-
|
304 |
-
def main():
|
305 |
-
with gr.Blocks(
|
306 |
-
css="""#chat_container {height: 820px; width: 1000px; margin-left: auto; margin-right: auto;}
|
307 |
-
#chatbot {height: 600px; overflow: auto;}
|
308 |
-
#create_container {height: 750px; margin-left: 0px; margin-right: 0px;}
|
309 |
-
#tokenizer_renderer span {white-space: pre-wrap}
|
310 |
-
"""
|
311 |
-
) as demo:
|
312 |
-
with gr.Row():
|
313 |
-
introduction()
|
314 |
-
with gr.Row():
|
315 |
-
chat_tab()
|
316 |
-
|
317 |
-
return demo
|
318 |
-
|
319 |
-
|
320 |
-
def start_demo():
|
321 |
-
demo = main()
|
322 |
-
if DEPLOYED:
|
323 |
-
demo.queue(api_open=False).launch(show_api=False)
|
324 |
-
else:
|
325 |
-
demo.queue()
|
326 |
-
demo.launch(share=False, server_name="0.0.0.0")
|
327 |
|
328 |
|
329 |
if __name__ == "__main__":
|
|
|
|
|
|
|
330 |
get_sotopia_profiles()
|
331 |
# prepare_model(DEFAULT_MODEL_SELECTION)
|
332 |
-
|
|
|
1 |
import os
|
2 |
+
import argparse
|
|
|
3 |
from typing import Literal
|
4 |
|
5 |
+
import gradio as gr # type: ignore
|
6 |
+
from sotopia_space.chat import chat_introduction, chat_tab, get_sotopia_profiles
|
7 |
+
from sotopia_space import benchmark
|
8 |
+
from ui_constants import CITATION_TEXT, BANNER
|
9 |
|
|
|
|
|
|
|
10 |
|
11 |
OPENAI_KEY_FILE="./openai_api.key"
|
12 |
if os.path.exists(OPENAI_KEY_FILE):
|
13 |
with open(OPENAI_KEY_FILE, "r") as f:
|
14 |
os.environ["OPENAI_API_KEY"] = f.read().strip()
|
15 |
|
16 |
+
with open("./sotopia_space/_header.md", "r") as f:
|
17 |
+
HEADER_MD = f.read()
|
|
|
|
|
|
|
18 |
|
19 |
+
def navigation_bar():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
with gr.Column(scale=2):
|
21 |
+
toggle_dark = gr.Button(value="Toggle Dark")
|
22 |
+
toggle_dark.click(
|
23 |
+
None,
|
24 |
+
js="""
|
25 |
+
() => {
|
26 |
+
if (document.body.classList.contains('dark')) {
|
27 |
+
document.body.classList.remove('dark');
|
28 |
+
document.querySelector('gradio-app').style.backgroundColor = 'var(--color-background-primary-light)';
|
29 |
+
} else {
|
30 |
+
document.body.classList.add('dark');
|
31 |
+
document.querySelector('gradio-app').style.backgroundColor = 'var(--color-background-primary-dark)';
|
32 |
+
}
|
33 |
+
}
|
34 |
+
""",
|
35 |
+
)
|
36 |
+
|
37 |
+
with gr.Blocks(
|
38 |
+
css="""#chat_container {height: 820px; width: 1000px; margin-left: auto; margin-right: auto;}
|
39 |
+
#chatbot {height: 600px; overflow: auto;}
|
40 |
+
#create_container {height: 750px; margin-left: 0px; margin-right: 0px;}
|
41 |
+
#tokenizer_renderer span {white-space: pre-wrap}
|
42 |
+
""",
|
43 |
+
theme="gradio/monochrome",
|
44 |
+
) as demo:
|
45 |
+
# with gr.Row():
|
46 |
+
# navigation_bar()
|
47 |
+
gr.Image(
|
48 |
+
"images/banner.png", elem_id="banner-image", show_label=False
|
49 |
)
|
50 |
+
gr.Markdown(HEADER_MD, elem_classes="markdown-text")
|
51 |
+
with gr.Tabs(elem_classes="tab-buttons") as tabs:
|
52 |
+
with gr.TabItem("🏅 Leaderboard", elem_id="benchmark-tab-table", id=0):
|
53 |
+
benchmark.benchmark_table()
|
54 |
+
with gr.TabItem("💬 Chat", elem_id="chat-tab-interface", id=1):
|
55 |
+
with gr.Row():
|
56 |
+
chat_introduction()
|
57 |
+
with gr.Row():
|
58 |
+
chat_tab()
|
59 |
+
with gr.Row():
|
60 |
+
with gr.Accordion("📙 Citation", open=False, elem_classes="accordion-label"):
|
61 |
+
gr.Textbox(
|
62 |
+
value=CITATION_TEXT,
|
63 |
+
lines=7,
|
64 |
+
label="Copy the BibTeX snippet to cite this source",
|
65 |
+
elem_id="citation-button",
|
66 |
+
show_copy_button=True)
|
67 |
+
|
68 |
+
# def start_demo():
|
69 |
+
# demo = main()
|
70 |
+
# if DEPLOYED:
|
71 |
+
# demo.queue(api_open=False).launch(show_api=False)
|
72 |
+
# else:
|
73 |
+
# demo.queue()
|
74 |
+
# demo.launch(share=False, server_name="0.0.0.0")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
|
76 |
|
77 |
if __name__ == "__main__":
|
78 |
+
parser = argparse.ArgumentParser()
|
79 |
+
parser.add_argument("--result_file", help="Path to results table", default="data_dir/models_vs_gpt35.jsonl")
|
80 |
+
#benchmark.original_df = pd.read_json(args.result_file, lines=True)
|
81 |
get_sotopia_profiles()
|
82 |
# prepare_model(DEFAULT_MODEL_SELECTION)
|
83 |
+
demo.launch()
|
data_dir/models_vs_gpt35.jsonl
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{"model_name": "GPT-4", "SOC [-10, 0]": -0.07, "SEC [-10, 0]": -0.14, "FIN [-5, 5]": 0.81, "REL [-5, 5]": 1.94, "KNO [0, 10]": 3.73, "GOAL [0, 10]": 7.62, "BEL [0, 10]": 9.28}
|
2 |
+
{"model_name": "GPT-3.5", "SOC [-10, 0]": -0.08, "SEC [-10, 0]": -0.08, "FIN [-5, 5]": 0.46, "REL [-5, 5]": 1.23, "KNO [0, 10]": 3.4, "GOAL [0, 10]": 6.45, "BEL [0, 10]": 9.15}
|
3 |
+
{"model_name": "Llama-2", "SOC [-10, 0]": -0.11, "SEC [-10, 0]": -0.14, "FIN [-5, 5]": 0.4, "REL [-5, 5]": 0.91, "KNO [0, 10]": 3.11, "GOAL [0, 10]": 5.38, "BEL [0, 10]": 8.1}
|
4 |
+
{"model_name": "MPT", "SOC [-10, 0]": -0.09, "SEC [-10, 0]": -0.07, "FIN [-5, 5]": 0.28, "REL [-5, 5]": 0.58, "KNO [0, 10]": 2.11, "GOAL [0, 10]": 4.1, "BEL [0, 10]": 6.17}
|
requirements.txt
CHANGED
@@ -8,7 +8,7 @@ annotated-types==0.6.0
|
|
8 |
anyio==3.7.1
|
9 |
attrs==23.2.0
|
10 |
beartype==0.14.1
|
11 |
-
bitsandbytes==0.
|
12 |
certifi==2024.2.2
|
13 |
cffi==1.16.0
|
14 |
charset-normalizer==3.3.2
|
@@ -68,18 +68,18 @@ mypy-extensions==1.0.0
|
|
68 |
names==0.3.0
|
69 |
networkx==3.3
|
70 |
numpy==1.26.4
|
71 |
-
nvidia-cublas-cu12==12.1.3.1
|
72 |
-
nvidia-cuda-cupti-cu12==12.1.105
|
73 |
-
nvidia-cuda-nvrtc-cu12==12.1.105
|
74 |
-
nvidia-cuda-runtime-cu12==12.1.105
|
75 |
-
nvidia-cudnn-cu12==8.9.2.26
|
76 |
-
nvidia-cufft-cu12==11.0.2.54
|
77 |
-
nvidia-curand-cu12==10.3.2.106
|
78 |
-
nvidia-cusolver-cu12==11.4.5.107
|
79 |
-
nvidia-cusparse-cu12==12.1.0.106
|
80 |
-
nvidia-nccl-cu12==2.19.3
|
81 |
-
nvidia-nvjitlink-cu12==12.4.127
|
82 |
-
nvidia-nvtx-cu12==12.1.105
|
83 |
openai==1.22.0
|
84 |
orjson==3.10.1
|
85 |
packaging==23.2
|
@@ -129,7 +129,7 @@ toolz==0.12.1
|
|
129 |
torch==2.2.2
|
130 |
tqdm==4.66.2
|
131 |
transformers==4.40.0
|
132 |
-
triton==2.2.0
|
133 |
typer==0.12.3
|
134 |
types-cffi==1.16.0.20240331
|
135 |
types-pyOpenSSL==24.0.0.20240417
|
|
|
8 |
anyio==3.7.1
|
9 |
attrs==23.2.0
|
10 |
beartype==0.14.1
|
11 |
+
bitsandbytes==0.42.0
|
12 |
certifi==2024.2.2
|
13 |
cffi==1.16.0
|
14 |
charset-normalizer==3.3.2
|
|
|
68 |
names==0.3.0
|
69 |
networkx==3.3
|
70 |
numpy==1.26.4
|
71 |
+
# nvidia-cublas-cu12==12.1.3.1
|
72 |
+
# nvidia-cuda-cupti-cu12==12.1.105
|
73 |
+
# nvidia-cuda-nvrtc-cu12==12.1.105
|
74 |
+
# nvidia-cuda-runtime-cu12==12.1.105
|
75 |
+
# nvidia-cudnn-cu12==8.9.2.26
|
76 |
+
# nvidia-cufft-cu12==11.0.2.54
|
77 |
+
# nvidia-curand-cu12==10.3.2.106
|
78 |
+
# nvidia-cusolver-cu12==11.4.5.107
|
79 |
+
# nvidia-cusparse-cu12==12.1.0.106
|
80 |
+
# nvidia-nccl-cu12==2.19.3
|
81 |
+
# nvidia-nvjitlink-cu12==12.4.127
|
82 |
+
# nvidia-nvtx-cu12==12.1.105
|
83 |
openai==1.22.0
|
84 |
orjson==3.10.1
|
85 |
packaging==23.2
|
|
|
129 |
torch==2.2.2
|
130 |
tqdm==4.66.2
|
131 |
transformers==4.40.0
|
132 |
+
# triton==2.2.0
|
133 |
typer==0.12.3
|
134 |
types-cffi==1.16.0.20240331
|
135 |
types-pyOpenSSL==24.0.0.20240417
|
sotopia_space/_header.md
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<br/>
|
2 |
+
|
3 |
+
# Sotopia Space: A Huggingface Space for the Sotopia projects
|
4 |
+
[⚙️ GitHub](https://github.com/sotopia-lab) | [🤗 HuggingFace](https://huggingface.co/collections/cmu-lti/sotopia-65f312c1bd04a8c4a9225e5b) | [💬 Discussions](https://github.com/orgs/sotopia-lab/discussions)
|
sotopia_space/benchmark.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr # type: ignore
|
2 |
+
import pandas as pd
|
3 |
+
from sotopia_space.constants import MODEL_OPTIONS
|
4 |
+
from sotopia_space.utils import estimated_win_rate, make_clickable_model, styled_error, styled_warning, styled_message,apply_length_penalty
|
5 |
+
|
6 |
+
LP_MODE = "v2"
|
7 |
+
original_df, ablation_df = None, None
|
8 |
+
LP_original_dfs = {}
|
9 |
+
DEFAULT_LP = 0.5
|
10 |
+
|
11 |
+
available_models = [] # to be filled in later
|
12 |
+
original_df, ablation_df = None, None
|
13 |
+
|
14 |
+
def slider_change_main(length_penalty):
|
15 |
+
global original_df, ablation_df, LP_MODE
|
16 |
+
adjusted_df = apply_length_penalty(original_df, ablation_df, length_penalty, mode=LP_MODE, LP_original_dfs=LP_original_dfs)
|
17 |
+
adjusted_df = adjusted_df[["Model", "Overall Elo", "Task-Avg Elo", "# battles", "Length"]]
|
18 |
+
adjusted_df = adjusted_df.sort_values(by="Overall Elo", ascending=False)
|
19 |
+
# adjusted_df = add_winrates(adjusted_df, LP=length_penalty)
|
20 |
+
# adjusted_df = adjusted_df.drop(columns=["Length"])
|
21 |
+
adjusted_df.insert(0, "Rank", range(1, 1 + len(adjusted_df)))
|
22 |
+
return adjusted_df
|
23 |
+
|
24 |
+
def slider_change_full(length_penalty, show_winrate):
|
25 |
+
global original_df, ablation_df, LP_MODE
|
26 |
+
adjusted_df = apply_length_penalty(original_df, ablation_df, length_penalty, mode=LP_MODE, LP_original_dfs=LP_original_dfs)
|
27 |
+
# sort the model by the "Task-Avg Elo" column
|
28 |
+
adjusted_df = adjusted_df.sort_values(by="Overall Elo", ascending=False)
|
29 |
+
adjusted_df.drop(columns=["Overall Elo", "Task-Avg Elo", "# battles", "Length"], inplace=True)
|
30 |
+
if show_winrate == "none":
|
31 |
+
adjusted_df.insert(0, "Rank", range(1, 1 + len(adjusted_df)))
|
32 |
+
return adjusted_df
|
33 |
+
elif show_winrate == "gpt-3.5":
|
34 |
+
adjusted_df = add_winrates_tasks(adjusted_df, ref="gpt-3.5", LP=length_penalty)
|
35 |
+
elif show_winrate == "gpt-4":
|
36 |
+
adjusted_df = add_winrates_tasks(adjusted_df, ref="gpt-4", LP=length_penalty)
|
37 |
+
adjusted_df.insert(0, "Rank", range(1, 1 + len(adjusted_df)))
|
38 |
+
return adjusted_df
|
39 |
+
|
40 |
+
def benchmark_table():
|
41 |
+
global original_df, ablation_df
|
42 |
+
global LP_original_dfs, LP_MODE
|
43 |
+
|
44 |
+
gr.Markdown(f"**Version**: sotopia (v1.01; 2024.04.22) | **# Examples**: 7200 | **# Models**: {len(MODEL_OPTIONS)} | **# Comparisons**: x", elem_classes="markdown-text")
|
45 |
+
|
46 |
+
with gr.TabItem("Vs GPT-3.5", elem_id="od-benchmark-tab-table-ablation", id=0, elem_classes="subtab"):
|
47 |
+
# original_df, ablation_df = skip_empty_original_df, skip_empty_ablation_df
|
48 |
+
original_df = pd.read_json('data_dir/models_vs_gpt35.jsonl', lines=True)
|
49 |
+
default_main_df = apply_length_penalty(original_df, ablation_df, length_penalty=DEFAULT_LP, mode=LP_MODE, LP_original_dfs=LP_original_dfs)
|
50 |
+
default_main_df = default_main_df.sort_values(by="GOAL [0, 10]", ascending=False)
|
51 |
+
# add a Rank column to the first columnn (starting from 1)
|
52 |
+
default_main_df.insert(0, "Rank", range(1, 1 + len(default_main_df)))
|
53 |
+
with gr.Row():
|
54 |
+
with gr.Column(scale=4):
|
55 |
+
gr.Markdown("**Vs GPT3.5**: The interlocutors are compared against GPT-3.5, the baseline model.")
|
56 |
+
with gr.Column(scale=1):
|
57 |
+
length_penlty_slider = gr.Slider(minimum=0.1, maximum=1, step=0.1, value=DEFAULT_LP, label="Length Penalty", elem_id="length-penalty-slider")
|
58 |
+
# checkbox_skip_empty = gr.Checkbox(label="Skip empty results", value=False, elem_id="skip-empty-checkbox", scale=2)
|
59 |
+
TYPES = ["number", "markdown", "number"]
|
60 |
+
leaderboard_table = gr.components.Dataframe(
|
61 |
+
value=default_main_df,
|
62 |
+
datatype=TYPES,
|
63 |
+
# max_rows=None,
|
64 |
+
height=1000,
|
65 |
+
elem_id="leaderboard-table",
|
66 |
+
interactive=False,
|
67 |
+
visible=True,
|
68 |
+
min_width=60,
|
69 |
+
)
|
70 |
+
#length_penlty_slider.change(fn=slider_change_main, inputs=[length_penlty_slider], outputs=[leaderboard_table])
|
sotopia_space/chat.py
ADDED
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import gradio as gr # type: ignore
|
3 |
+
# Functions for creating the chat interface
|
4 |
+
from functools import cache
|
5 |
+
from typing import Literal
|
6 |
+
import json
|
7 |
+
from collections import defaultdict
|
8 |
+
from utils import Environment, Agent, get_context_prompt, dialogue_history_prompt
|
9 |
+
from sotopia_pi_generate import prepare_model, generate_action
|
10 |
+
from sotopia_space.constants import MODEL_OPTIONS
|
11 |
+
|
12 |
+
DEPLOYED = os.getenv("DEPLOYED", "true").lower() == "true"
|
13 |
+
DEFAULT_MODEL_SELECTION = "gpt-3.5-turbo"
|
14 |
+
TEMPERATURE = 0.7
|
15 |
+
TOP_P = 1
|
16 |
+
MAX_TOKENS = 1024
|
17 |
+
|
18 |
+
ENVIRONMENT_PROFILES = "profiles/environment_profiles.jsonl"
|
19 |
+
AGENT_PROFILES = "profiles/agent_profiles.jsonl"
|
20 |
+
RELATIONSHIP_PROFILES = "profiles/relationship_profiles.jsonl"
|
21 |
+
|
22 |
+
Action = Literal['none', 'action', 'non-verbal communication', 'speak', 'leave']
|
23 |
+
ACTION_TYPES: list[Action] = ['none', 'action', 'non-verbal communication', 'speak', 'leave']
|
24 |
+
|
25 |
+
|
26 |
+
|
27 |
+
@cache
|
28 |
+
def get_sotopia_profiles(env_file=ENVIRONMENT_PROFILES, agent_file=AGENT_PROFILES, relationship_file=RELATIONSHIP_PROFILES):
|
29 |
+
with open(env_file, 'r') as f:
|
30 |
+
data = [json.loads(line) for line in f.readlines()]
|
31 |
+
|
32 |
+
code_names_count = defaultdict(int)
|
33 |
+
environments = []
|
34 |
+
environment_dict = {}
|
35 |
+
for profile in sorted(data, key=lambda x: x['codename']):
|
36 |
+
env_obj = Environment(profile)
|
37 |
+
if profile['codename'] in code_names_count:
|
38 |
+
environments.append((
|
39 |
+
"{}_{:05d}".format(profile['codename'],
|
40 |
+
code_names_count[profile['codename']]
|
41 |
+
),
|
42 |
+
env_obj._id
|
43 |
+
))
|
44 |
+
else:
|
45 |
+
environments.append((profile['codename'], env_obj._id))
|
46 |
+
environment_dict[env_obj._id] = env_obj
|
47 |
+
code_names_count[profile['codename']] += 1
|
48 |
+
|
49 |
+
with open(agent_file, 'r') as f:
|
50 |
+
data = [json.loads(line) for line in f.readlines()]
|
51 |
+
|
52 |
+
agent_dict = {}
|
53 |
+
for profile in data:
|
54 |
+
agent_obj = Agent(profile)
|
55 |
+
agent_dict[agent_obj._id] = agent_obj
|
56 |
+
|
57 |
+
with open(relationship_file, 'r') as f:
|
58 |
+
data = [json.loads(line) for line in f.readlines()]
|
59 |
+
|
60 |
+
relationship_dict = defaultdict(lambda : defaultdict(list))
|
61 |
+
for profile in data:
|
62 |
+
relationship_dict[profile['relationship']][profile['agent1_id']].append(profile['agent2_id'])
|
63 |
+
relationship_dict[profile['relationship']][profile['agent2_id']].append(profile['agent1_id'])
|
64 |
+
|
65 |
+
return environments, environment_dict, agent_dict, relationship_dict
|
66 |
+
|
67 |
+
def chat_introduction():
|
68 |
+
with gr.Column(scale=2):
|
69 |
+
gr.Image(
|
70 |
+
"images/sotopia.jpg", elem_id="banner-image", show_label=False
|
71 |
+
)
|
72 |
+
with gr.Column(scale=5):
|
73 |
+
gr.Markdown(
|
74 |
+
"""# Sotopia Space
|
75 |
+
**Chat with different social agent models including [sotopia-pi](https://github.com/sotopia-lab/sotopia-pi), GPT and so on in sotopia space!**
|
76 |
+
|
77 |
+
➡️️ **Intended Use**: Sotopia space is intended to showcase the social intelligence ability of different social agents in interesting social scenarios.
|
78 |
+
|
79 |
+
✨ **Guidance**:
|
80 |
+
|
81 |
+
Step (1) Select a social scenario that interests you in "Scenario Selection"
|
82 |
+
|
83 |
+
Step (2) Select a social agent you want to chat with in "Model Selection"
|
84 |
+
|
85 |
+
Step (3) Select which character you and your social agent will play in the scenario in "User Agent Selection" and "Bot Agent Selection"
|
86 |
+
|
87 |
+
Step (4) Negotiate/debate/cooperate with the social agent to see whether your goal or their social goal can be achieved.
|
88 |
+
|
89 |
+
⚠️ **Limitations**: The social agent can and will produce factually incorrect information, hallucinating facts and potentially offensive actions. It can produce problematic outputs, especially if prompted to do so.
|
90 |
+
|
91 |
+
🗄️ **Disclaimer**: User prompts and generated replies from the model may be collected solely for the purpose of pure academic research. By using this demo, users implicitly agree to these terms.
|
92 |
+
"""
|
93 |
+
)
|
94 |
+
# with gr.Column(scale=1):
|
95 |
+
# toggle_dark = gr.Button(value="Toggle Dark")
|
96 |
+
|
97 |
+
def create_user_agent_dropdown(environment_id):
    """Build the user-side agent dropdown for the given scenario id."""
    _, environment_dict, agent_dict, relationship_dict = get_sotopia_profiles()
    environment = environment_dict[environment_id]

    # Every agent that appears as a key in this scenario's relationship
    # graph is a candidate for the human player.
    candidate_ids = set(relationship_dict[environment.relationship].keys())
    choices = [(agent_dict[agent_id].name, agent_id) for agent_id in candidate_ids]
    default_value = choices[0][1] if choices else None
    return gr.Dropdown(choices=choices, value=default_value, label="User Agent Selection")
+
|
110 |
+
def create_bot_agent_dropdown(environment_id, user_agent_id):
    """Build the bot-side agent dropdown: graph neighbours of the chosen user agent."""
    _, environment_dict, agent_dict, relationship_dict = get_sotopia_profiles()
    environment = environment_dict[environment_id]
    user_agent = agent_dict[user_agent_id]

    choices = [
        (agent_dict[neighbor_id].name, neighbor_id)
        for neighbor_id in relationship_dict[environment.relationship][user_agent.agent_id]
    ]
    default_value = choices[0][1] if choices else None
    return gr.Dropdown(choices=choices, value=default_value, label="Bot Agent Selection")
+
|
120 |
+
def create_environment_info(environment_dropdown):
    """Show the scenario description for the selected environment id."""
    _, environment_dict, _, _ = get_sotopia_profiles()
    scenario_text = environment_dict[environment_dropdown].scenario
    return gr.Textbox(label="Scenario", lines=1, value=scenario_text)
+
|
126 |
+
def create_user_info(user_agent_dropdown):
    """Show the user agent's background and personality as one profile textbox."""
    _, _, agent_dict, _ = get_sotopia_profiles()
    agent = agent_dict[user_agent_dropdown]
    profile_text = f"{agent.background} {agent.personality}"
    return gr.Textbox(label="User Agent Profile", lines=4, value=profile_text)
+
|
132 |
+
def create_bot_info(bot_agent_dropdown):
    """Show the bot agent's background and personality as one profile textbox."""
    _, _, agent_dict, _ = get_sotopia_profiles()
    agent = agent_dict[bot_agent_dropdown]
    profile_text = f"{agent.background} {agent.personality}"
    return gr.Textbox(label="Bot Agent Profile", lines=4, value=profile_text)
+
|
138 |
+
def create_user_goal(environment_dropdown):
    """Show the first agent goal of the scenario, with markup rendered as plain text."""
    _, environment_dict, _, _ = get_sotopia_profiles()
    text = environment_dict[environment_dropdown].agent_goals[0]
    # Parentheses are presentation noise in the stored goals.
    text = text.replace('(', '').replace(')', '')
    # Turn <extra_info>/<strategy_hint> markup into paragraph breaks
    # (extra_info first, matching the original replacement order).
    for tag in ("extra_info", "strategy_hint"):
        if f"<{tag}>" in text:
            text = text.replace(f"<{tag}>", "\n\n")
            text = text.replace(f"</{tag}>", "\n")
    return gr.Textbox(label="User Agent Goal", lines=4, value=text)
+
|
150 |
+
def create_bot_goal(environment_dropdown):
    """Show the second agent goal of the scenario, with markup rendered as plain text."""
    _, environment_dict, _, _ = get_sotopia_profiles()
    text = environment_dict[environment_dropdown].agent_goals[1]
    # Parentheses are presentation noise in the stored goals.
    text = text.replace('(', '').replace(')', '')
    # Turn <extra_info>/<strategy_hint> markup into paragraph breaks
    # (extra_info first, matching the original replacement order).
    for tag in ("extra_info", "strategy_hint"):
        if f"<{tag}>" in text:
            text = text.replace(f"<{tag}>", "\n\n")
            text = text.replace(f"</{tag}>", "\n")
    return gr.Textbox(label="Bot Agent Goal", lines=4, value=text)
+
|
162 |
+
def sotopia_info_accordion(accordion_visible=True):
    """Build the scenario/model/agent selection panel and wire up its callbacks.

    Returns the four dropdowns the chat tab needs:
    (model, scenario, user agent, bot agent).
    """
    environments, _, _, _ = get_sotopia_profiles()

    with gr.Accordion("Create your sotopia space!", open=accordion_visible):
        with gr.Row():
            environment_dropdown = gr.Dropdown(
                choices=environments,
                label="Scenario Selection",
                value=environments[0][1] if environments else None,
                interactive=True,
            )
            model_name_dropdown = gr.Dropdown(
                choices=MODEL_OPTIONS,
                value=DEFAULT_MODEL_SELECTION,
                interactive=True,
                label="Model Selection",
            )

        with gr.Row():
            user_agent_dropdown = create_user_agent_dropdown(environment_dropdown.value)
            bot_agent_dropdown = create_bot_agent_dropdown(environment_dropdown.value, user_agent_dropdown.value)

    with gr.Accordion("Check your social task!", open=accordion_visible):
        scenario_info_display = create_environment_info(environment_dropdown.value)

        with gr.Row():
            bot_goal_display = create_bot_goal(environment_dropdown.value)
            user_goal_display = create_user_goal(environment_dropdown.value)

        with gr.Row():
            bot_agent_info_display = create_bot_info(bot_agent_dropdown.value)
            user_agent_info_display = create_user_info(user_agent_dropdown.value)

    # Keep every display in sync with the selections above.  The
    # registration order is preserved from the original implementation.
    # Scenario change -> refresh user dropdown.
    environment_dropdown.change(fn=create_user_agent_dropdown, inputs=[environment_dropdown], outputs=[user_agent_dropdown])
    # User or scenario change -> refresh bot dropdown.
    user_agent_dropdown.change(fn=create_bot_agent_dropdown, inputs=[environment_dropdown, user_agent_dropdown], outputs=[bot_agent_dropdown])
    # Scenario change -> refresh scenario text.
    environment_dropdown.change(fn=create_environment_info, inputs=[environment_dropdown], outputs=[scenario_info_display])
    # User change -> refresh user profile.
    user_agent_dropdown.change(fn=create_user_info, inputs=[user_agent_dropdown], outputs=[user_agent_info_display])
    # Bot change -> refresh bot profile.
    bot_agent_dropdown.change(fn=create_bot_info, inputs=[bot_agent_dropdown], outputs=[bot_agent_info_display])
    # Scenario change -> refresh both goals.
    environment_dropdown.change(fn=create_user_goal, inputs=[environment_dropdown], outputs=[user_goal_display])
    environment_dropdown.change(fn=create_bot_goal, inputs=[environment_dropdown], outputs=[bot_goal_display])

    return model_name_dropdown, environment_dropdown, user_agent_dropdown, bot_agent_dropdown
+
|
215 |
+
def instructions_accordion(instructions, according_visible=False):
    """Wrap the (currently hidden) instructions text in a collapsed accordion.

    Returns the created Textbox component so callers can reference it.
    """
    with gr.Accordion("Instructions", open=False, visible=according_visible):
        instructions = gr.Textbox(
            lines=10,
            value=instructions,
            interactive=False,
            placeholder="Instructions",
            show_label=False,
            max_lines=10,
            visible=False,  # kept invisible for now; toggled elsewhere if needed
        )
    return instructions
+
|
228 |
+
def chat_tab():
    """Assemble the chat tab: selection accordions plus the ChatInterface."""
    # history is a list of (user message, bot message) pairs managed by gradio.
    _, environment_dict, agent_dict, _ = get_sotopia_profiles()

    def run_chat(
        message,
        history,
        environment_selection,
        user_agent_dropdown,
        bot_agent_dropdown,
        model_selection: str,
    ):
        """Generate the bot agent's next utterance for the current turn."""
        environment = environment_dict[environment_selection]
        user_agent = agent_dict[user_agent_dropdown]
        bot_agent = agent_dict[bot_agent_dropdown]

        # Prompt = scenario/agent context followed by the dialogue so far.
        context = get_context_prompt(bot_agent, user_agent, environment)
        dialogue_history, next_turn_idx = dialogue_history_prompt(message, history, user_agent, bot_agent)
        prompt_history = f"{context}{dialogue_history}"
        agent_action = generate_action(model_selection, prompt_history, next_turn_idx, ACTION_TYPES, bot_agent.name, TEMPERATURE)
        return agent_action.to_natural_language()

    with gr.Column():
        with gr.Blocks():
            model_name_dropdown, scenario_dropdown, user_agent_dropdown, bot_agent_dropdown = sotopia_info_accordion()

        with gr.Column():
            with gr.Accordion("Start the conversation to achieve your goal!", open=True):
                gr.ChatInterface(
                    fn=run_chat,
                    chatbot=gr.Chatbot(
                        height=620,
                        render=False,
                        show_label=False,
                        rtl=False,
                        avatar_images=(
                            "images/profile1.jpg",
                            "images/profile2.jpg",
                        ),
                    ),
                    textbox=gr.Textbox(
                        placeholder="Write your message here...",
                        render=False,
                        scale=7,
                        rtl=False,
                    ),
                    # Extra inputs are forwarded to run_chat after (message, history).
                    additional_inputs=[
                        scenario_dropdown,
                        user_agent_dropdown,
                        bot_agent_dropdown,
                        model_name_dropdown,
                    ],
                    submit_btn="Send",
                    stop_btn="Stop",
                    retry_btn="🔄 Retry",
                    undo_btn="↩️ Delete",
                    clear_btn="🗑️ Clear",
                )
|
sotopia_space/constants.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Models selectable in the chat tab's "Model Selection" dropdown.
MODEL_OPTIONS = [
    "gpt-3.5-turbo",
    "gpt-4",
    "gpt-4-turbo",
    "cmu-lti/sotopia-pi-mistral-7b-BC_SR",
    "cmu-lti/sotopia-pi-mistral-7b-BC_SR_4bit",
    "mistralai/Mistral-7B-Instruct-v0.1"
    # "mistralai/Mixtral-8x7B-Instruct-v0.1",
    # "togethercomputer/llama-2-7b-chat",
    # "togethercomputer/llama-2-70b-chat",
    # "togethercomputer/mpt-30b-chat",
    # "together_ai/togethercomputer/llama-2-7b-chat",
    # "together_ai/togethercomputer/falcon-7b-instruct",
]

# Leaderboard display metadata: internal key -> pretty name + model page.
# hf_model_id is either a Hugging Face repo id (open model) or an http(s)
# URL (closed model); make_clickable_model keys its icon off that.
MODEL_INFO = {
    "Llama-2-13b-chat-hf.nosp": {"pretty_name": "Llama-2-13B-chat", "hf_model_id": "meta-llama/Llama-2-13b-chat-hf"},
    "Llama-2-70b-chat-hf.nosp": {"pretty_name": "Llama-2-70B-chat", "hf_model_id": "meta-llama/Llama-2-70b-chat-hf"},
    "Llama-2-7b-chat-hf.nosp": {"pretty_name": "Llama-2-7B-chat", "hf_model_id": "meta-llama/Llama-2-7b-chat-hf"},
    "Llama-2-7b-chat-hf": {"pretty_name": "Llama-2-7B-chat (+sys prmpt)", "hf_model_id": "meta-llama/Llama-2-7b-chat-hf"},
    "Mistral-7B-Instruct-v0.1": {"pretty_name": "Mistral-7B-Instruct", "hf_model_id": "mistralai/Mistral-7B-Instruct-v0.1"},
    "Mistral-7B-Instruct-v0.2": {"pretty_name": "Mistral-7B-Instruct (v0.2)", "hf_model_id": "mistralai/Mistral-7B-Instruct-v0.2"},
    "Mixtral-8x7B-Instruct-v0.1": {"pretty_name": "Mixtral-8x7B-Instruct", "hf_model_id": "mistralai/Mixtral-8x7B-Instruct-v0.1"},
    "Nous-Hermes-2-Mixtral-8x7B-DPO": {"pretty_name": "Nous-Hermes-2-Mixtral-8x7B-DPO", "hf_model_id": "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO"},
    "Yi-34B-Chat": {"pretty_name": "Yi-34B-Chat", "hf_model_id": "01-ai/Yi-34B"},
    "gemini-1.0-pro": {"pretty_name": "gemini-1.0-pro", "hf_model_id": "https://blog.google/technology/ai/google-gemini-ai/"},
    "gemma-7b-it": {"pretty_name": "Gemma-7B-it", "hf_model_id": "google/gemma-7b"},
    "gpt-3.5-turbo-0125": {"pretty_name": "gpt-3.5-turbo-0125", "hf_model_id": "https://platform.openai.com/"},
    "gpt-4-0125-preview": {"pretty_name": "gpt-4-0125-preview", "hf_model_id": "https://platform.openai.com/"},
    "tulu-2-dpo-70b": {"pretty_name": "Tulu-2-dpo-70b", "hf_model_id": "cmu-lti/tulu-2-dpo-70b"},
    "vicuna-13b-v1.5": {"pretty_name": "Vicuna-13b-v1.5", "hf_model_id": "lmsys/vicuna-13b-v1.5"},
    "zephyr-7b-beta": {"pretty_name": "Zephyr-7b-beta", "hf_model_id": "HuggingFaceH4/zephyr-7b-beta"},
    "mistral-large-2402": {"pretty_name": "Mistral-Large", "hf_model_id": "https://mistral.ai/news/mistral-large/"},
    "claude-3-opus-20240229": {"pretty_name": "Claude 3 Opus", "hf_model_id": "https://www.anthropic.com/claude"},
    "claude-3-sonnet-20240229": {"pretty_name": "Claude 3 Sonnet", "hf_model_id": "https://www.anthropic.com/claude"},
    "zephyr-7b-gemma-v0.1": {"pretty_name": "Zephyr-7b-Gemma", "hf_model_id": "HuggingFaceH4/zephyr-7b-gemma-v0.1"},
    "Starling-LM-7B-beta": {"pretty_name": "StarlingLM-7B-beta", "hf_model_id": "Nexusflow/Starling-LM-7B-beta"},
    "dbrx-instruct": {"pretty_name": "DBRX Instruct", "hf_model_id": "databricks/dbrx-instruct"}
}
|
sotopia_space/utils.py
ADDED
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datasets import load_dataset, Dataset
|
2 |
+
import os
|
3 |
+
import json
|
4 |
+
from datasets import load_dataset
|
5 |
+
from datasets.utils.logging import disable_progress_bar # type: ignore
|
6 |
+
from ui_constants import column_names, all_task_types
|
7 |
+
import random
|
8 |
+
disable_progress_bar()
|
9 |
+
import math
|
10 |
+
from sotopia_space.constants import MODEL_INFO
|
11 |
+
|
12 |
+
# Lazily-populated session_id -> episode dict (see load_benchdata_dict).
id_to_data = None
# Optional per-model average-length info used by post_processing.
model_len_info = None


def make_clickable_model(model_name):
    """Return an HTML link for a known model, or the raw name if unknown.

    Closed models (hf_model_id is an http(s) URL) get a 🔒 marker; open
    Hugging Face models get a 🔥 marker and a link to their model page.
    """
    info = MODEL_INFO.get(model_name)
    if info is None:
        return model_name
    if info["hf_model_id"].startswith("http"):
        icon, link = "🔒", info["hf_model_id"]
    else:
        icon, link = "🔥", f"https://huggingface.co/{info['hf_model_id']}"
    return f'{icon} <a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{info["pretty_name"]}</a>'
+
|
28 |
+
|
29 |
+
def styled_error(error):
    """Red, centered HTML banner for error messages."""
    return f"<p style='color: red; font-size: 20px; text-align: center;'>{error}</p>"

def styled_warning(warn):
    """Orange, centered HTML banner for warnings."""
    return f"<p style='color: orange; font-size: 20px; text-align: center;'>{warn}</p>"

def styled_message(message):
    """Green, centered HTML banner for success/info messages."""
    return f"<p style='color: green; font-size: 20px; text-align: center;'>{message}</p>"
+
|
38 |
+
|
39 |
+
def estimated_win_rate(elo_a, elo_b, LP=0):
    """Elo-based expected score, returned as a percentage.

    NOTE(review): despite the argument order, the value returned is
    100 * (1 - P(A wins)), i.e. the win rate of the *second* rating
    against the first.  Callers pass the reference model as ``elo_a``
    and read off each row model's win% — keep that contract.

    :param elo_a: Elo rating of the reference player A
    :param elo_b: Elo rating of player B
    :param LP: scales the rating gap by 10**LP before conversion
    """
    scaled_gap = (elo_b - elo_a) * (10 ** LP)
    prob_a_wins = 1 / (1 + 10 ** (scaled_gap / 400))
    return 100 * (1 - prob_a_wins)
+
|
50 |
+
|
51 |
+
|
52 |
+
# Formats the columns
|
53 |
+
def formatter(x):
    """Round numeric cells to one decimal; pass strings through unchanged.

    Uses an exact ``type`` check (not isinstance) to match the original
    behaviour for str subclasses.
    """
    return x if type(x) is str else round(x, 1)
+
|
60 |
+
|
61 |
+
def add_winrates(current_df, LP=0):
    """Append 'Win% vs GPT-4' / 'Win% vs GPT-3.5T' columns derived from Task-Avg Elo.

    Also moves the '# battles' and 'Length' columns to the end of the frame.
    Returns a new DataFrame; ``current_df`` is not modified.
    """
    df = current_df.copy()
    elo_column = "Task-Avg Elo"

    # Anchor Elo ratings: first row whose model name matches each reference.
    gpt4_elo = df.loc[df["Model"].str.contains("gpt-4"), elo_column].iloc[0]
    gpt35_elo = df.loc[df["Model"].str.contains("gpt-3.5"), elo_column].iloc[0]

    # Win rate of each row model against the two anchors, then formatted.
    df['Win% vs GPT-4'] = df[elo_column].apply(lambda x: estimated_win_rate(gpt4_elo, x, LP=LP)).apply(formatter)
    df['Win% vs GPT-3.5T'] = df[elo_column].apply(lambda x: estimated_win_rate(gpt35_elo, x, LP=LP)).apply(formatter)

    # Push the bookkeeping columns to the end.
    cols = list(df.columns)
    for tail_col in ("# battles", "Length"):
        cols.remove(tail_col)
        cols.append(tail_col)
    return df[cols]
+
|
82 |
+
def add_winrates_tasks(current_df, ref="gpt-4", LP=0):
    """Replace each per-task Elo column with win rates against the ``ref`` model."""
    new_df = current_df.copy()
    for task in all_task_types:
        col = column_names[task]
        # Elo of the reference model in this task column.
        ref_elo = current_df[current_df["Model"].str.contains(ref)][col].iloc[0]
        new_df[col] = current_df[col].apply(lambda x: estimated_win_rate(ref_elo, x, LP=LP)).apply(formatter)
    return new_df
+
|
90 |
+
|
91 |
+
def post_processing(df, model_len_info):
    """Prettify the raw leaderboard frame in place and return it.

    :param df: raw leaderboard DataFrame with a "model name " column
        (note the trailing space, which exists in the source data) and
        per-task Elo columns.
    :param model_len_info: optional dict mapping raw model name ->
        {"avg_len": ...}; when truthy, adds a "Length" column.
    :return: renamed, rounded, link-ified frame sorted by "Task-Avg Elo"
        with "Model" and "Task-Avg Elo" moved to the front.
    """
    if model_len_info:
        df["Length"] = df["model name "].apply(lambda x: model_len_info[x]["avg_len"])

    for col in df.columns:
        if col == "model name ":
            # The original wrote x.replace(x, make_clickable_model(x)),
            # which is just make_clickable_model(x) for every string x.
            df[col] = df[col].apply(make_clickable_model)
        else:
            df[col] = df[col].apply(formatter)  # numeric rounding
    df.rename(columns=column_names, inplace=True)
    df.sort_values(by="Task-Avg Elo", inplace=True, ascending=False)
    # Put "Model" and "Task-Avg Elo" first; keep the rest in their order.
    df = df[["Model", "Task-Avg Elo"] + [col for col in df.columns if col not in ["Model", "Task-Avg Elo"]]]
    return df
+
|
107 |
+
def apply_length_penalty(original_df, ablation_df, length_penalty=0.2, mode='v1', LP_original_dfs=None):
    """Length-penalty adjustment is currently DISABLED: returns ``original_df`` unchanged.

    The signature is kept intact so existing callers keep working.  The
    previous implementation (see version control) subtracted a length-
    penalty-weighted ablation score from every Elo column ('v1'/'v1.1')
    or looked up a precomputed frame ('v2' via ``LP_original_dfs``).

    :param original_df: leaderboard frame; returned as-is.
    :param ablation_df: unused while the feature is disabled.
    :param length_penalty: unused while the feature is disabled.
    :param mode: unused while the feature is disabled.
    :param LP_original_dfs: unused while the feature is disabled.
    """
    return original_df
+
|
135 |
+
def load_benchdata():
    """Download and return the sotopia test split from the Hugging Face Hub."""
    print("Loading sotopia data...")
    return load_dataset("cmu-lti/sotopia", split="test")
+
|
140 |
+
def load_benchdata_dict():
    """Download the episode file and return it indexed by session_id."""
    print("Loading sotopia data....")
    episodes = load_dataset("cmu-lti/sotopia", data_files="sotopia_episodes_v1_hf.jsonl")['train']
    return {item["session_id"]: item for item in episodes}
+
|
148 |
+
def load_eval_results():
    """Download and return the full sotopia evaluation split."""
    print("Loading sotopia Evaluation data...")
    return load_dataset("WildEval/sotopia-Evaluation", "all", split="train")
+
|
153 |
+
def load_infer_results(model_name):
    """Download and return the inference results for one model."""
    print(f"Loading sotopia Results for {model_name}...")
    return load_dataset("WildEval/sotopia-Results", model_name, split="train")
+
|
158 |
+
def sample_an_eval_result(eval_results, model_list=None, tag_list=None):
    """Pick a random evaluation item matching the model/tag filters.

    Fixes two defects of the original: mutable default arguments
    (``model_list=[], tag_list=[]``) and an ``UnboundLocalError`` when no
    item passed the filters (``result_dict`` was referenced unassigned).

    :param eval_results: iterable of evaluation dicts.
    :param model_list: with 2+ names, both sides must be listed; with one
        name, at least one side must match; empty/None means no filter.
    :param tag_list: keep items sharing at least one tag; empty/None means
        no filter.
    :return: dict describing the sampled battle, or None if nothing matched.
    """
    global id_to_data
    model_list = model_list or []
    tag_list = tag_list or []

    candidates = list(eval_results)
    random.shuffle(candidates)
    for eval_item in candidates:
        assignment = eval_item['assignment']
        model_1, model_2 = eval_item['model_1'], eval_item['model_2']
        # Resolve which underlying model played side A and which side B.
        model_A = model_1 if assignment['A'] == model_1 else model_2
        model_B = model_2 if assignment['B'] == model_2 else model_1

        # Model filter: pairwise when 2+ given, one-sided when exactly one.
        if len(model_list) >= 2:
            if model_A not in model_list or model_B not in model_list:
                continue
        elif len(model_list) == 1:
            if model_A != model_list[0] and model_B != model_list[0]:
                continue
        # Tag filter: require at least one shared tag.
        if tag_list and set(tag_list).isdisjoint(set(eval_item['tags'])):
            continue

        model_A_output = eval_item['model_1_output'] if model_1 == model_A else eval_item['model_2_output']
        model_B_output = eval_item['model_2_output'] if model_2 == model_B else eval_item['model_1_output']
        # Skip battles where either side produced nothing.
        if not model_A_output.strip() or not model_B_output.strip():
            continue

        data = id_to_data[eval_item['session_id']]
        return {
            "session_id": eval_item['session_id'],
            "model_A": model_A,
            "model_B": model_B,
            "winner": eval_item['winner'],
            "intent": data["intent"],
            "task_type": eval_item['tags'][0],  # primary task type
            "all_tags": eval_item['tags'],
            "chat_history": eval_item['history'],
            "last_query": eval_item['last_query'],
            "conversation_input": data["conversation_input"],
            "model_A_output": model_A_output,
            "model_B_output": model_B_output,
            "reason": eval_item['parsed_result']["reason"],
            "choice": eval_item['parsed_result']["choice"],
            "checklist": data["checklist"],
        }
    # No item survived the filters (original raised UnboundLocalError here).
    return None
|
223 |
+
#id_to_data = load_benchdata_dict()
|
start_app.sh
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
#!/bin/bash
# Load API credentials from local key files (quoted to avoid word
# splitting/globbing of the substituted values), then launch the app
# through the gradio CLI for auto-reload instead of `python app.py`.
export OPENAI_API_KEY="$(cat openai_api.key)"
export HF_TOKEN="$(cat hf_token.key)"

gradio app.py
ui_constants.py
ADDED
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
|
3 |
+
# Default length-penalty weight for the leaderboard slider.
DEFAULT_LP = 0.5

# Banner image hosted in the sotopia website repo.
banner_url = "https://github.com/sotopia-lab/sotopia-website/blob/main/public/bg_xl.png" # the same repo here.
BANNER = f'<div style="display: flex; justify-content: flex-start;"><img src="{banner_url}" alt="Banner" style="width: 40vw; min-width: 300px; max-width: 800px;"> </div>'

TITLE = "<html> <head> <style> h1 {text-align: center;} </style> </head> <body> <h1> 🦁 AI2 sotopia Leaderboard </b> </body> </html>"

WINRATE_HEATMAP = "<div><img src='https://github.com/WildEval/sotopia-Leaderboard/blob/main/gradio/pairwise_win_fractions.png?raw=true' style='width:100%;'></div>"

# BibTeX entry shown in the "cite us" box (SOTOPIA, ICLR 2024).
CITATION_TEXT = """@inproceedings{
zhou2024sotopia,
title={{SOTOPIA}: Interactive Evaluation for Social Intelligence in Language Agents},
author={Xuhui Zhou and Hao Zhu and Leena Mathur and Ruohong Zhang and Haofei Yu and Zhengyang Qi and Louis-Philippe Morency and Yonatan Bisk and Daniel Fried and Graham Neubig and Maarten Sap},
booktitle={The Twelfth International Conference on Learning Representations},
year={2024},
url={https://openreview.net/forum?id=mM7VurbA4r}
}
"""
|
22 |
+
|
23 |
+
# Raw-data column name -> display name used in the leaderboard table.
column_names = {
    "model name ": "Model",  # NOTE: the trailing space exists in the source data
    "elo overall": "Overall Elo",
    'Information seeking': 'InfoSek',
    'Creative Writing': 'CrtWrt',
    'Coding & Debugging': 'Code',
    'Reasoning': 'Reason',
    'Editing': 'Edit',
    'Math': 'Math',
    'Planning': 'Plan',
    'Brainstorming': 'Brnstrm',
    'Role playing': 'RolPly',
    'Advice seeking': 'AdvSek',
    'Data Analysis': 'DataAna',
    'Others': 'Misc',
    "average": "Task-Avg Elo",
}

# Task categories, in display order; each must have an entry in column_names.
all_task_types = [
    'Information seeking',
    'Creative Writing',
    'Coding & Debugging',
    'Reasoning',
    'Editing',
    'Math',
    'Planning',
    'Brainstorming',
    'Role playing',
    'Advice seeking',
    'Data Analysis',
    'Others'
]
|
56 |
+
|
57 |
+
|
58 |
+
js_light = """
|
59 |
+
function refresh() {
|
60 |
+
const url = new URL(window.location);
|
61 |
+
if (url.searchParams.get('__theme') !== 'light') {
|
62 |
+
url.searchParams.set('__theme', 'light');
|
63 |
+
window.location.href = url.href;
|
64 |
+
}
|
65 |
+
}
|
66 |
+
"""
|
67 |
+
|
68 |
+
js_code = """
|
69 |
+
function scroll_top() {
|
70 |
+
console.log("Hello from Gradio!");
|
71 |
+
const bubbles = document.querySelectorAll('.bubble-wrap');
|
72 |
+
bubbles.forEach((bubble, index) => {
|
73 |
+
setTimeout(() => {
|
74 |
+
bubble.scrollTop = 0;
|
75 |
+
}, index * 100); // Delay of 100ms between each iteration
|
76 |
+
});
|
77 |
+
}
|
78 |
+
"""
|
79 |
+
|
80 |
+
|
81 |
+
# One-line legend mapping task abbreviations to full names.
TASK_TYPE_STR = "**Tasks**: Info seeking (**InfoSek**), Creative Writing (**CrtWrt**), Coding&Debugging (**Code**), Reasoning (**Reason**), Editing (**Edit**), **Math**, Planning (**Plan**), Brainstorming (**Brnstrm**), Role playing (**RolPly**), Advice seeking (**AdvSek**), Data Analysis (**DataAna**)"

# Custom CSS injected into the leaderboard / chat UI.
css = """
code {
    font-size: large;
}
footer {visibility: hidden}
.top-left-LP{
    margin-top: 6px;
    margin-left: 5px;
}
.markdown-text{font-size: 14pt}
.markdown-text-small{font-size: 13pt}
.markdown-text-tiny{font-size: 12pt}
.markdown-text-tiny-red{
    font-size: 12pt;
    color: red;
    background-color: yellow;
    font-color: red;
    font-weight: bold;
}
th {
    text-align: center;
    font-size: 17px; /* Adjust the font size as needed */
}
td {
    font-size: 15px; /* Adjust the font size as needed */
    text-align: center;
}
.sample_button{
    border: 1px solid #000000;
    border-radius: 5px;
    padding: 5px;
    font-size: 15pt;
    font-weight: bold;
    margin: 5px;
}
.chat-common{
    height: auto;
    max-height: 400px;
    min-height: 100px;
}
.chat-specific{
    height: auto;
    max-height: 600px;
    min-height: 200px;
}
#od-benchmark-tab-table-button{
    font-size: 15pt;
    font-weight: bold;
}
.btn_boderline{
    border: 1px solid #000000;
    border-radius: 5px;
    padding: 5px;
    margin: 5px;
    font-size: 15pt;
    font-weight: bold;
}
.btn_boderline_next{
    border: 0.1px solid #000000;
    border-radius: 5px;
    padding: 5px;
    margin: 5px;
    font-size: 15pt;
    font-weight: bold;
}
.btn_boderline_gray{
    border: 0.5px solid gray;
    border-radius: 5px;
    padding: 5px;
    margin: 5px;
    font-size: 15pt;
    font-weight: italic;
}
.btn_boderline_selected{
    border: 2px solid purple;
    background-color: #f2f2f2;
    border-radius: 5px;
    padding: 5px;
    margin: 5px;
    font-size: 15pt;
    font-weight: bold;
}
.accordion-label button span{
    font-size: 14pt;
    font-weight: bold;
}
#select-models span{
    font-size: 10pt;
}
#select-tasks span{
    font-size: 10pt;
}
.markdown-text-details{
    margin: 10px;
    padding: 10px;
}
button.selected[role="tab"][aria-selected="true"] {
    font-size: 18px; /* or any other size you prefer */
    font-weight: bold;
}
#od-benchmark-tab-table-ablation-button {
    font-size: larger; /* Adjust the font size as needed */
}
.plotly-plot{
    height: auto;
    max-height: 600px;
    min-height: 600px;
}
"""