Spaces:
Runtime error
Runtime error
import random | |
import gradio as gr | |
import sys | |
import traceback | |
import pandas as pd | |
import gradio as gr | |
import json | |
import yaml | |
# from tqdm import tqdm | |
from scripts.UBAR_code.interaction import UBAR_interact | |
from scripts.user_model_code.interaction import multiwoz_interact | |
from scripts.UBAR_code.interaction.UBAR_interact import bcolors | |
# --- Agents -----------------------------------------------------------------
UBAR_checkpoint_path = "epoch50_trloss0.59_gpt2"
user_model_checkpoint_path = "MultiWOZ-full_checkpoint_step340k"

# A single dialogue-system object is created and bound to two names, so the
# "Dialogue System" tab and the "Self-Play" tab talk to the SAME instance.
# NOTE(review): because the object is shared, init_session()/response() calls
# made from one tab also affect the other -- confirm this sharing is intended
# (e.g. to avoid loading the checkpoint twice).
sys_model = UBAR_interact.UbarSystemModel(
    "UBAR_sys_model",
    UBAR_checkpoint_path,
    "scripts/UBAR_code/interaction/config.yaml",
)
self_play_sys_model = sys_model

# Same arrangement for the user simulator: one object, two names.
user_model = multiwoz_interact.NeuralAgent(
    "user",
    user_model_checkpoint_path,
    "scripts/user_model_code/interaction/config.yaml",
)
self_play_user_model = user_model

# --- Goals ------------------------------------------------------------------
n_goals = 100
goals_path = "data/raw/UBAR/multi-woz/data.json"
print("Loading goals...")
goals = multiwoz_interact.read_multiWOZ_20_goals(goals_path, n_goals)

# Pick a random starting goal for the User Simulator tab
# (a new one can be drawn later via change_goal()).
curr_goal_idx = random.randint(0, n_goals - 1)
current_goal = goals[curr_goal_idx]
user_model.init_session(ini_goal=current_goal)

# Pick an independent random starting goal for the Self-Play tab.
curr_sp_goal_idx = random.randint(0, n_goals - 1)
current_sp_goal = goals[curr_sp_goal_idx]
self_play_user_model.init_session(ini_goal=current_sp_goal)

# --- Conversation transcripts -----------------------------------------------
# One list of 2-tuples per tab; the chatbot callbacks append to these and
# return them to the UI.
ds_history, us_history, self_play_history = [], [], []
def reset_ds_state():
    """Empty the Dialogue System transcript and start a new system session.

    Returns:
        The module-level ``ds_history`` list, emptied in place.
    """
    del ds_history[:]
    sys_model.init_session()
    return ds_history
def reset_us_state():
    """Empty the User Simulator transcript and restart it on the current goal.

    The simulator session is re-initialised with the module-level
    ``current_goal``; the goal itself is not changed here.

    Returns:
        The module-level ``us_history`` list, emptied in place.
    """
    del us_history[:]
    user_model.init_session(ini_goal=current_goal)
    return us_history
def reset_self_play_state():
    """Empty the Self-Play transcript and restart both self-play agents.

    The system agent gets a fresh session and the user agent is
    re-initialised with the module-level ``current_sp_goal``.

    Returns:
        The module-level ``self_play_history`` list, emptied in place.
    """
    del self_play_history[:]
    self_play_sys_model.init_session()
    self_play_user_model.init_session(ini_goal=current_sp_goal)
    return self_play_history
def change_goal():
    """Draw a new random goal for the User Simulator tab and reset that tab.

    Rebinds the module-level ``curr_goal_idx`` / ``current_goal`` and then
    calls ``reset_us_state()`` so the simulator restarts on the new goal.

    Returns:
        A 2-tuple ``(goal_yaml, history)``: the new goal serialised as
        block-style YAML, and the cleared User Simulator transcript.
    """
    global curr_goal_idx, current_goal

    curr_goal_idx = random.randint(0, n_goals - 1)
    current_goal = goals[curr_goal_idx]

    # Reset only after the new goal is bound: reset_us_state() reads it.
    cleared_history = reset_us_state()
    goal_as_yaml = yaml.dump(current_goal, default_flow_style=False)
    return goal_as_yaml, cleared_history
def change_sp_goal():
    """Draw a new random goal for the Self-Play tab and reset that tab.

    Rebinds the module-level ``curr_sp_goal_idx`` / ``current_sp_goal`` and
    then calls ``reset_self_play_state()`` so both agents restart.

    Returns:
        A 2-tuple ``(goal_yaml, history)``: the new goal serialised as
        block-style YAML, and the cleared Self-Play transcript.
    """
    global curr_sp_goal_idx, current_sp_goal

    curr_sp_goal_idx = random.randint(0, n_goals - 1)
    current_sp_goal = goals[curr_sp_goal_idx]

    # Reset only after the new goal is bound: reset_self_play_state() reads it.
    cleared_history = reset_self_play_state()
    goal_as_yaml = yaml.dump(current_sp_goal, default_flow_style=False)
    return goal_as_yaml, cleared_history
def ds_chatbot(user_utt):
    """Send one user message to the dialogue system and record the exchange.

    Args:
        user_utt: The text typed by the human user.

    Returns:
        The module-level ``ds_history`` list of ``(user_utt, sys_response)``
        tuples, including the exchange just added.
    """
    # The turn index passed to the model is the number of exchanges so far.
    turn_id = len(ds_history)
    sys_response = sys_model.response(user_utt, turn_id)
    # Capitalise the first character. Slicing (rather than indexing [0])
    # keeps this safe when the model returns an empty string.
    sys_response = sys_response[:1].upper() + sys_response[1:]
    ds_history.append((user_utt, sys_response))
    return ds_history
def us_chatbot(sys_response):
    """Send one system message to the user simulator and record the exchange.

    Args:
        sys_response: The text typed by the human, who plays the dialogue
            system in this tab.

    Returns:
        The module-level ``us_history`` list of ``(sys_response, user_utt)``
        tuples.
    """
    simulated_reply = user_model.response(sys_response)
    us_history.append((sys_response, simulated_reply))

    # Once the simulator reports the dialogue is over, move on to a new goal.
    # NOTE(review): change_goal() clears us_history via reset_us_state(), so
    # the list returned below is empty on the terminating turn -- confirm
    # that wiping the final exchange from the display is intended.
    if user_model.is_terminated():
        change_goal()
    return us_history
def self_play():
    """Advance the self-play dialogue by one user turn and one system turn.

    The user simulator replies to the last system message (or to an empty
    string on the opening turn), the dialogue system answers, and the pair
    is appended to the Self-Play transcript.

    Returns:
        The module-level ``self_play_history`` list of
        ``(user_utt, sys_response)`` tuples.
    """
    # On the first step there is no previous system message to respond to.
    previous_sys_response = self_play_history[-1][1] if self_play_history else ""
    user_utt = self_play_user_model.response(previous_sys_response)

    # The turn index passed to the system is the number of exchanges so far.
    turn_id = len(self_play_history)
    sys_response = self_play_sys_model.response(user_utt, turn_id)
    # Capitalise the first character. Slicing (rather than indexing [0])
    # keeps this safe when the model returns an empty string.
    sys_response = sys_response[:1].upper() + sys_response[1:]
    self_play_history.append((user_utt, sys_response))

    # When the simulated user is done, draw a new *self-play* goal. The
    # previous code called change_goal(), which reset the User Simulator
    # tab's goal and transcript instead of this tab's.
    # NOTE(review): change_sp_goal() clears self_play_history, so the list
    # returned below is empty on the terminating turn -- same pattern as
    # us_chatbot(); confirm this is the desired display behaviour.
    if self_play_user_model.is_terminated():
        change_sp_goal()
    return self_play_history
# Start every tab from a clean state. These calls run once, when the module
# is executed; they are NOT triggered by a browser refresh (the page text
# below tells users that state persists across refreshes).
reset_ds_state()
reset_us_state()
reset_self_play_state()

# Build the demo UI.
# Components are created with gr.Textbox / gr.Chatbot directly: the older
# gr.inputs.* / gr.outputs.* namespaces are deprecated and are not available
# in newer Gradio releases.
# NOTE(review): the nesting of the layout below (what sits inside each Row /
# TabItem) was reconstructed from the order of the context managers -- verify
# against the deployed app.
# NOTE(review): the emoji in the Markdown strings appear mis-encoded in this
# copy of the file; they are kept as-is here and should be restored from the
# original source.
block = gr.Blocks()
with block:
    gr.Markdown("# π¬ Jointly Optimized Task-Oriented Dialogue System And User Simulator π¬")
    gr.Markdown(
        "Created by [Alistair McLeay](https://alistairmcleay.com) for the [Masters in Machine Learning & Machine Intelligence at Cambridge University](https://www.mlmi.eng.cam.ac.uk/). <br/>"
        "Thank you to [Professor Bill Byrne](https://sites.google.com/view/bill-byrne/home) for his supervision and guidance. <br/> "
        "Thank you to [Andy Tseng](https://github.com/andy194673) and [Alex Coca](https://github.com/alexcoca) who provided code and guidance."
    )
    gr.Markdown(
        "Both Systems are trained on the [MultiWOZ dataset](https://github.com/budzianowski/multiwoz). <br/> "
        "Supported domains are: <br> "
        "1. π Train, 2. π¨ Hotel, 3. π Taxi, 4. π Police, 5. π£ Restaurant, 6. πΏ Attraction, 7. π₯ Hospital."
    )
    gr.Markdown(
        "**Please note:** <br> "
        "1. These systems are in development and are full of funny little bugs, as is this app. <br> "
        "2. If you refresh this page the conversation state will persist. To reset a conversation you need to click 'Reset Conversation' below."
    )

    with gr.Tabs():
        # Tab 1: human plays the user, the model plays the system.
        with gr.TabItem("Dialogue System"):
            gr.Markdown(
                "This bot is a Task-Oriented Dialogue System. <br> "
                "You are the user. Go ahead and try to book a train, or a hotel etc."
            )
            with gr.Row():
                ds_input_text = gr.Textbox(
                    label="User Message",
                    placeholder="I'd like to book a train from Cambridge to London",
                )
                ds_response = gr.Chatbot(label="Dialogue System Response")
            ds_button = gr.Button("Submit Message")
            reset_ds_button = gr.Button("Reset Conversation")

        # Tab 2: human plays the system, the model plays the user.
        with gr.TabItem("User Simulator"):
            gr.Markdown(
                "This bot is a User Simulator. <br> "
                "You are the Task-Oriented Dialogue System. Your job is to help the user with their requests. <br> "
                "If you want the User Simulator to have a different goal press 'Generate New Goal'."
            )
            with gr.Row():
                us_input_text = gr.Textbox(
                    label="Dialogue System Message",
                    placeholder="How can I help you today?",
                )
                us_response = gr.Chatbot(label="User Simulator Response")
            us_button = gr.Button("Submit Message")
            reset_us_button = gr.Button("Reset Conversation")
            new_goal_button = gr.Button("Generate New Goal")
            current_goal_yaml = gr.Textbox(label="New Goal (YAML)")

        # Tab 3: both sides are played by models.
        with gr.TabItem("Self-Play"):
            gr.Markdown(
                "In this case both the User Simulator and the Task-Oriented Dialogue System are agents. <br> "
                "Get them to interact by pressing 'Run Next Step'. <br> "
                "If you want the User Simulator to have a different goal press 'Generate New Goal'."
            )
            self_play_response = gr.Chatbot(label="Self-Play Output")
            self_play_button = gr.Button("Run Next Step")
            reset_self_play_button = gr.Button("Reset Conversation")
            new_sp_goal_button = gr.Button("Generate New Goal")
            current_sp_goal_yaml = gr.Textbox(label="New Goal (YAML)")

    gr.Markdown("Want to get in touch? [Email me](mailto:am@alistairmcleay.com)")

    # Wire buttons to callbacks: click(fn, inputs, outputs).
    ds_button.click(ds_chatbot, ds_input_text, ds_response)
    us_button.click(us_chatbot, us_input_text, us_response)
    self_play_button.click(self_play, None, self_play_response)
    # The goal-changing callbacks return (goal_yaml, cleared_history).
    new_goal_button.click(change_goal, None, [current_goal_yaml, us_response])
    new_sp_goal_button.click(change_sp_goal, None, [current_sp_goal_yaml, self_play_response])
    reset_ds_button.click(reset_ds_state, None, ds_response)
    reset_us_button.click(reset_us_state, None, us_response)
    reset_self_play_button.click(reset_self_play_state, None, self_play_response)

block.launch()