# agora-demo / app.py
# Author: samuelemarro
# Last commit: 19bcd88 "Added schemas.json." (7.61 kB)
import json
import gradio as gr
from collections import UserList
from flow import full_flow
from utils import use_cost_tracker, get_costs, compute_hash
# Task schemas offered in the "Schema" dropdown, keyed by task name. Each entry
# is a mapping of schema fields (the UI edits 'description', 'input', 'output',
# 'tools' and 'examples'). The encoding is explicit so loading does not depend
# on the platform's default locale encoding.
with open('schemas.json', 'r', encoding='utf-8') as f:
    SCHEMAS = json.load(f)
def parse_raw_messages(messages_raw):
    """Turn raw protocol messages into two chat transcripts.

    Args:
        messages_raw: iterable of message dicts. Each needs a 'role' key, plus
            'message' when its 'status' is 'error', or 'body' otherwise.

    Returns:
        A ``(messages_clean, messages_agora)`` pair of ``{'role', 'content'}``
        lists. The clean transcript shows the message body, or
        ``"Error: <message>"`` for failed messages. The agora transcript shows
        every field except 'role' as pretty-printed JSON in a fenced code
        block. The input dicts are not modified.
    """
    readable, verbatim = [], []
    for msg in messages_raw:
        speaker = msg['role']

        # Raw view: all remaining fields, JSON-encoded inside a ``` fence.
        payload = {field: value for field, value in msg.items() if field != 'role'}
        fenced = '\n'.join(('```', json.dumps(payload, indent=2), '```'))
        verbatim.append({'role': speaker, 'content': fenced})

        # Human view: just the body, or the error text for failed messages.
        failed = msg.get('status') == 'error'
        text = f"Error: {msg['message']}" if failed else msg['body']
        readable.append({'role': speaker, 'content': text})

    return readable, verbatim
def main():
with gr.Blocks() as demo:
gr.Markdown("### Agora Demo")
gr.Markdown("We will create a new Agora channel and offer it to Alice as a tool.")
chosen_task = gr.Dropdown(choices=list(SCHEMAS.keys()), label="Schema", value="weather_forecast")
custom_task = gr.Checkbox(label="Custom Task")
STATE_TRACKER = {}
@gr.render(inputs=[chosen_task, custom_task])
def render(chosen_task, custom_task):
if STATE_TRACKER.get('chosen_task') != chosen_task:
STATE_TRACKER['chosen_task'] = chosen_task
for k, v in SCHEMAS[chosen_task].items():
if isinstance(v, str):
STATE_TRACKER[k] = v
else:
STATE_TRACKER[k] = json.dumps(v, indent=2)
if custom_task:
gr.Text(label="Description", value=STATE_TRACKER["description"], interactive=True).change(lambda x: STATE_TRACKER.update({'description': x}))
gr.TextArea(label="Input Schema", value=STATE_TRACKER["input"], interactive=True).change(lambda x: STATE_TRACKER.update({'input': x}))
gr.TextArea(label="Output Schema", value=STATE_TRACKER["output"], interactive=True).change(lambda x: STATE_TRACKER.update({'output': x}))
gr.TextArea(label="Tools", value=STATE_TRACKER["tools"], interactive=True).change(lambda x: STATE_TRACKER.update({'tools': x}))
gr.TextArea(label="Examples", value=STATE_TRACKER["examples"], interactive=True).change(lambda x: STATE_TRACKER.update({'examples': x}))
model_options = [
('GPT 4o (Camel AI)', 'gpt-4o'),
('GPT 4o-mini (Camel AI)', 'gpt-4o-mini'),
('Claude 3 Sonnet (LangChain)', 'claude-3-sonnet'),
('Gemini 1.5 Pro (Google GenAI)', 'gemini-1.5-pro'),
('Llama3 405B (Sambanova + LangChain)', 'llama3-405b')
]
fallback_image = ''
images = {
'gpt-4o': 'https://uxwing.com/wp-content/themes/uxwing/download/brands-and-social-media/chatgpt-icon.png',
'gpt-4o-mini': 'https://uxwing.com/wp-content/themes/uxwing/download/brands-and-social-media/chatgpt-icon.png',
'claude-3-5-sonnet-latest': 'https://play-lh.googleusercontent.com/4S1nfdKsH_1tJodkHrBHimqlCTE6qx6z22zpMyPaMc_Rlr1EdSFDI1I6UEVMnokG5zI',
'claude-3-5-haiku-latest': 'https://play-lh.googleusercontent.com/4S1nfdKsH_1tJodkHrBHimqlCTE6qx6z22zpMyPaMc_Rlr1EdSFDI1I6UEVMnokG5zI',
'gemini-1.5-pro': 'https://uxwing.com/wp-content/themes/uxwing/download/brands-and-social-media/google-gemini-icon.png',
'llama3-405b': 'https://www.designstub.com/png-resources/wp-content/uploads/2023/03/meta-icon-social-media-flat-graphic-vector-3-novem.png'
}
with gr.Row(equal_height=True):
with gr.Column(scale=1):
alice_model_dd = gr.Dropdown(label="Alice Model", choices=model_options, value="gpt-4o")
with gr.Column(scale=1):
bob_model_dd = gr.Dropdown(label="Bob Model", choices=model_options, value="gpt-4o")
button = gr.Button('Start', elem_id='start_button')
gr.Markdown('### Natural Language')
@gr.render(inputs=[alice_model_dd, bob_model_dd])
def render_with_images(alice_model, bob_model):
avatar_images = [images.get(alice_model, fallback_image), images.get(bob_model, fallback_image)]
chatbot_nl = gr.Chatbot(type="messages", avatar_images=avatar_images)
with gr.Accordion(label="Raw Messages", open=False):
chatbot_nl_raw = gr.Chatbot(type="messages", avatar_images=avatar_images)
gr.Markdown('### Negotiation')
chatbot_negotiation = gr.Chatbot(type="messages", avatar_images=avatar_images)
gr.Markdown('### Protocol')
protocol_result = gr.TextArea(interactive=False, label="Protocol")
gr.Markdown('### Implementation')
with gr.Row():
with gr.Column(scale=1):
alice_implementation = gr.TextArea(interactive=False, label="Alice Implementation")
with gr.Column(scale=1):
bob_implementation = gr.TextArea(interactive=False, label="Bob Implementation")
gr.Markdown('### Structured Communication')
structured_communication = gr.Chatbot(type="messages", avatar_images=avatar_images)
with gr.Accordion(label="Raw Messages", open=False):
structured_communication_raw = gr.Chatbot(type="messages", avatar_images=avatar_images)
def respond(chosen_task, custom_task, alice_model, bob_model):
yield gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False), \
None, None, None, None, None, None, None, None
if custom_task:
schema = dict(STATE_TRACKER)
for k, v in schema.items():
if isinstance(v, str):
try:
schema[k] = json.loads(v)
except:
pass
else:
schema = SCHEMAS[chosen_task]
for nl_messages_raw, negotiation_messages, structured_messages_raw, protocol, alice_implementation, bob_implementation in full_flow(schema, alice_model, bob_model):
nl_messages_clean, nl_messages_agora = parse_raw_messages(nl_messages_raw)
structured_messages_clean, structured_messages_agora = parse_raw_messages(structured_messages_raw)
yield gr.update(), gr.update(), gr.update(), nl_messages_clean, nl_messages_agora, negotiation_messages, structured_messages_clean, structured_messages_agora, protocol, alice_implementation, bob_implementation
#yield from full_flow(schema, alice_model, bob_model)
yield gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
button.click(respond, [chosen_task, custom_task, alice_model_dd, bob_model_dd], [button, alice_model_dd, bob_model_dd, chatbot_nl, chatbot_nl_raw, chatbot_negotiation, structured_communication, structured_communication_raw, protocol_result, alice_implementation, bob_implementation])
demo.launch()
if __name__ == "__main__":
    # Launch the UI only when executed as a script, not when imported.
    main()