import gradio as gr
from gradio import ChatInterface, Request
from gradio.helpers import special_args
import anyio
import os
import threading
import sys
from itertools import chain

import autogen
from autogen.code_utils import extract_code
from autogen import UserProxyAgent, AssistantAgent, Agent, OpenAIWrapper

LOG_LEVEL = "INFO"
TIMEOUT = 60

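# Gradio's ChatInterface appends [message, response] to the chat history itself;
# here the respond() callback (defined below) mutates the history in place instead,
# so this subclass overrides the internal _submit_fn to skip the extra append and
# simply return the history unchanged. Note that _submit_fn and special_args are
# Gradio internals and may change between Gradio versions.
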
class myChatInterface(ChatInterface):
    async def _submit_fn(
        self,
        message: str,
        history_with_input: list[list[str | None]],
        request: Request,
        *args,
    ) -> tuple[list[list[str | None]], list[list[str | None]]]:
        history = history_with_input[:-1]
        inputs, _, _ = special_args(
            self.fn, inputs=[message, history, *args], request=request
        )

        if self.is_async:
            response = await self.fn(*inputs)
        else:
            response = await anyio.to_thread.run_sync(
                self.fn, *inputs, limiter=self.limiter
            )

        # history.append([message, response])
        return history, history

with gr.Blocks() as demo:

    def flatten_chain(list_of_lists):
        return list(chain.from_iterable(list_of_lists))

    class thread_with_trace(threading.Thread):
        # https://www.geeksforgeeks.org/python-different-ways-to-kill-a-thread/
        # https://stackoverflow.com/questions/6893968/how-to-get-the-return-value-from-a-thread
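        #
        # The kill mechanism works through Python's tracing hooks: start() swaps in
        # __run, which installs globaltrace via sys.settrace before running the real
        # target. Once kill() sets self.killed, the next "line" trace event raises
        # SystemExit inside the target thread. Caveat: trace events only fire for
        # pure-Python bytecode, so a thread blocked inside a C call (e.g. a pending
        # network request) will not die until control returns to Python code.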
        def __init__(self, *args, **keywords):
            threading.Thread.__init__(self, *args, **keywords)
            self.killed = False
            self._return = None

        def start(self):
            self.__run_backup = self.run
            self.run = self.__run
            threading.Thread.start(self)

        def __run(self):
            sys.settrace(self.globaltrace)
            self.__run_backup()
            self.run = self.__run_backup

        def run(self):
            if self._target is not None:
                self._return = self._target(*self._args, **self._kwargs)

        def globaltrace(self, frame, event, arg):
            if event == "call":
                return self.localtrace
            else:
                return None

        def localtrace(self, frame, event, arg):
            if self.killed:
                if event == "line":
                    raise SystemExit()
            return self.localtrace

        def kill(self):
            self.killed = True
        def join(self, timeout=None):
            # Default to None (wait indefinitely) rather than 0, so that the bare
            # join() issued after kill() actually waits for the thread to exit.
            threading.Thread.join(self, timeout)
            return self._return

    def update_agent_history(recipient, messages, sender, config):
        if config is None:
            config = recipient
        if messages is None:
            messages = recipient._oai_messages[sender]
        message = messages[-1]
        msg = message.get("content", "")
        # config.append(msg) if msg is not None else None  # config can be agent_history
        return False, None  # required to ensure the agent communication flow continues
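
    # In AutoGen's register_reply protocol, a reply function returns a tuple
    # (final, reply): returning (False, None) means "this hook produced no final
    # reply", so the agents fall through to their remaining reply logic. This hook
    # is currently unused; the register_reply calls in initialize_agents below are
    # commented out.
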
    def _is_termination_msg(message):
        """Check if a message is a termination message.

        Terminate when no code block is detected. Currently only detects python code blocks.
        """
        if isinstance(message, dict):
            message = message.get("content")
            if message is None:
                return False
        cb = extract_code(message)
        contain_code = False
        for c in cb:
            # todo: support more languages
            if c[0] == "python":
                contain_code = True
                break
        return not contain_code
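
    # Example of the heuristic: an assistant reply containing a fenced
    # ```python ... ``` block is treated as "there is still code to run", so the
    # auto-reply loop continues; a reply that is prose only (or code in another
    # language) counts as a termination message and ends the exchange.
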
    def initialize_agents(config_list):
        assistant = AssistantAgent(
            name="assistant",
            max_consecutive_auto_reply=5,
            llm_config={
                # "seed": 42,
                "timeout": TIMEOUT,
                "config_list": config_list,
            },
        )

        userproxy = UserProxyAgent(
            name="userproxy",
            human_input_mode="NEVER",
            is_termination_msg=_is_termination_msg,
            max_consecutive_auto_reply=5,
            # code_execution_config=False,
            code_execution_config={
                "work_dir": "coding",
                "use_docker": False,  # set to True or image name like "python:3" to use docker
            },
        )

        # assistant.register_reply([Agent, None], update_agent_history)
        # userproxy.register_reply([Agent, None], update_agent_history)

        return assistant, userproxy
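
    # Division of labor: the assistant writes code, and the userproxy executes it
    # in ./coding and feeds the result (including the exit code) back. With
    # use_docker=False the generated code runs directly on the host, which is
    # convenient for a demo but unsafe for untrusted inputs; Docker execution is
    # the safer choice.
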
    def chat_to_oai_message(chat_history):
        """Convert chat history to OpenAI message format."""
        messages = []
        if LOG_LEVEL == "DEBUG":
            print(f"chat_to_oai_message: {chat_history}")
        for msg in chat_history:
            messages.append(
                {
                    "content": msg[0].split()[0]
                    if msg[0].startswith("exitcode")
                    else msg[0],
                    "role": "user",
                }
            )
            messages.append({"content": msg[1], "role": "assistant"})
        return messages
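
    # Illustrative example (hypothetical values):
    #   chat_to_oai_message([["2+2?", "4"]])
    #   -> [{"content": "2+2?", "role": "user"},
    #       {"content": "4", "role": "assistant"}]
    # User turns that start with "exitcode" are execution reports; only their first
    # token (e.g. "exitcode:") is kept, so long run logs are not replayed verbatim.
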
    def oai_message_to_chat(oai_messages, sender):
        """Convert OpenAI message format to chat history."""
        chat_history = []
        messages = oai_messages[sender]
        if LOG_LEVEL == "DEBUG":
            print(f"oai_message_to_chat: {messages}")
        for i in range(0, len(messages), 2):
            chat_history.append(
                [
                    messages[i]["content"],
                    messages[i + 1]["content"] if i + 1 < len(messages) else "",
                ]
            )
        return chat_history
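
    # The inverse direction: consecutive OpenAI messages are paired off into
    # [user, assistant] rows, with an empty string filling in when the transcript
    # ends on an unanswered user message.
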
    def agent_history_to_chat(agent_history):
        """Convert agent history to chat history."""
        chat_history = []
        for i in range(0, len(agent_history), 2):
            chat_history.append(
                [
                    agent_history[i],
                    agent_history[i + 1] if i + 1 < len(agent_history) else None,
                ]
            )
        return chat_history
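
    # initiate_chat below gives the assistant multi-round memory without persisting
    # agent state between turns: the assistant is reset, and the prior Gradio chat
    # history is replayed by temporarily appending it (in OpenAI message format) to
    # the assistant's system message, which is restored afterwards.
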
    def initiate_chat(config_list, user_message, chat_history):
        if LOG_LEVEL == "DEBUG":
            print(f"chat_history_init: {chat_history}")
        # agent_history = flatten_chain(chat_history)
        if len(config_list[0].get("api_key", "")) < 2:
            chat_history.append(
                [
                    user_message,
                    "Hi, nice to meet you! Please enter your API keys in the text boxes below.",
                ]
            )
            return chat_history
        else:
            llm_config = {
                # "seed": 42,
                "timeout": TIMEOUT,
                "config_list": config_list,
            }
            assistant.llm_config.update(llm_config)
            assistant.client = OpenAIWrapper(**assistant.llm_config)

        if user_message.strip().lower().startswith("show file:"):
            filename = user_message.strip().lower().replace("show file:", "").strip()
            filepath = os.path.join("coding", filename)
            if os.path.exists(filepath):
                chat_history.append([user_message, (filepath,)])
            else:
                chat_history.append([user_message, f"File {filename} not found."])
            return chat_history

        assistant.reset()
        oai_messages = chat_to_oai_message(chat_history)
        assistant._oai_system_message_origin = assistant._oai_system_message.copy()
        assistant._oai_system_message += oai_messages

        try:
            userproxy.initiate_chat(assistant, message=user_message)
            messages = userproxy.chat_messages
            chat_history += oai_message_to_chat(messages, assistant)
            # agent_history = flatten_chain(chat_history)
        except Exception as e:
            # agent_history += [user_message, str(e)]
            # chat_history[:] = agent_history_to_chat(agent_history)
            chat_history.append([user_message, str(e)])

        assistant._oai_system_message = assistant._oai_system_message_origin.copy()

        if LOG_LEVEL == "DEBUG":
            print(f"chat_history: {chat_history}")
            # print(f"agent_history: {agent_history}")

        return chat_history
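
    # Watchdog pattern: initiate_chat can block for a long time inside an LLM call
    # or while executing generated code, so it runs in a killable thread. If it has
    # not finished within TIMEOUT seconds, the thread is killed and a timeout
    # message is shown instead of freezing the UI.
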
    def chatbot_reply_thread(input_text, chat_history, config_list):
        """Chat with the agents, running initiate_chat in a killable thread."""
        thread = thread_with_trace(
            target=initiate_chat, args=(config_list, input_text, chat_history)
        )
        thread.start()
        try:
            messages = thread.join(timeout=TIMEOUT)
            if thread.is_alive():
                thread.kill()
                thread.join()
                # Wrap in a list of [user, bot] pairs so gr.Chatbot can render it.
                messages = [
                    [
                        input_text,
                        "Timeout Error: Please check your API keys and try again later.",
                    ]
                ]
        except Exception as e:
            messages = [
                [
                    input_text,
                    str(e)
                    if len(str(e)) > 0
                    else "Invalid Request to OpenAI, please check your API keys.",
                ]
            ]
        return messages

    def chatbot_reply_plain(input_text, chat_history, config_list):
        """Chat with the agents synchronously, without the watchdog thread."""
        try:
            messages = initiate_chat(config_list, input_text, chat_history)
        except Exception as e:
            messages = [
                [
                    input_text,
                    str(e)
                    if len(str(e)) > 0
                    else "Invalid Request to OpenAI, please check your API keys.",
                ]
            ]
        return messages

    def chatbot_reply(input_text, chat_history, config_list):
        """Chat with the agents; dispatches to the threaded implementation."""
        return chatbot_reply_thread(input_text, chat_history, config_list)

    def get_description_text():
        return """
        # Microsoft AutoGen: Multi-Round Human Interaction Chatbot Demo

        This demo shows how to build a chatbot which can handle multi-round conversations with human interactions.

        #### [AutoGen](https://github.com/microsoft/autogen) [Discord](https://discord.gg/pAbnFJrkgZ) [Paper](https://arxiv.org/abs/2308.08155) [SourceCode](https://github.com/thinkall/autogen-demos)
        """

    def update_config():
        config_list = autogen.config_list_from_models(
            model_list=[os.environ.get("MODEL", "gpt-35-turbo")],
        )
        if not config_list:
            config_list = [
                {
                    "api_key": "",
                    "base_url": "",
                    "api_type": "azure",
                    "api_version": "2023-07-01-preview",
                    "model": "gpt-35-turbo",
                }
            ]
        return config_list

    def set_params(model, oai_key, aoai_key, aoai_base):
        os.environ["MODEL"] = model
        os.environ["OPENAI_API_KEY"] = oai_key
        os.environ["AZURE_OPENAI_API_KEY"] = aoai_key
        os.environ["AZURE_OPENAI_API_BASE"] = aoai_base
    def respond(message, chat_history, model, oai_key, aoai_key, aoai_base):
        set_params(model, oai_key, aoai_key, aoai_base)
        config_list = update_config()
        chat_history[:] = chatbot_reply(message, chat_history, config_list)
        if LOG_LEVEL == "DEBUG":
            print(f"return chat_history: {chat_history}")
        return ""

    config_list = [
        {
            "api_key": "",
            "base_url": "",
            "api_type": "azure",
            "api_version": "2023-07-01-preview",
            "model": "gpt-35-turbo",
        }
    ]
    assistant, userproxy = initialize_agents(config_list)

    description = gr.Markdown(get_description_text())

    with gr.Row() as params:
        txt_model = gr.Dropdown(
            label="Model",
            choices=[
                "gpt-4",
                "gpt-35-turbo",
                "gpt-3.5-turbo",
            ],
            allow_custom_value=True,
            value="gpt-35-turbo",
            container=True,
        )
        txt_oai_key = gr.Textbox(
            label="OpenAI API Key",
            placeholder="Enter OpenAI API Key",
            max_lines=1,
            show_label=True,
            container=True,
            type="password",
        )
        txt_aoai_key = gr.Textbox(
            label="Azure OpenAI API Key",
            placeholder="Enter Azure OpenAI API Key",
            max_lines=1,
            show_label=True,
            container=True,
            type="password",
        )
        txt_aoai_base_url = gr.Textbox(
            label="Azure OpenAI API Base",
            placeholder="Enter Azure OpenAI Base Url",
            max_lines=1,
            show_label=True,
            container=True,
            type="password",
        )

    chatbot = gr.Chatbot(
        [],
        elem_id="chatbot",
        bubble_full_width=False,
        avatar_images=(
            "human.png",
            (os.path.join(os.path.dirname(__file__), "autogen.png")),
        ),
        render=False,
        height=800,
    )

    txt_input = gr.Textbox(
        scale=4,
        show_label=False,
        placeholder="Enter text and press enter",
        container=False,
        render=False,
        autofocus=True,
    )
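
    # gr.Chatbot renders a one-element tuple like (filepath,) as a file message
    # (images are shown inline), which is what the "show file:" branch in
    # initiate_chat relies on for the stock_price.png example below.
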
    chatiface = myChatInterface(
        respond,
        chatbot=chatbot,
        textbox=txt_input,
        additional_inputs=[
            txt_model,
            txt_oai_key,
            txt_aoai_key,
            txt_aoai_base_url,
        ],
        examples=[
            ["write a python function to compute the sum of two numbers."],
            ["what about the product of two numbers?"],
            [
                "Plot a chart of the last year's stock prices of Microsoft, Google and Apple and save to stock_price.png."
            ],
            ["show file: stock_price.png"],
        ],
    )

if __name__ == "__main__": | |
demo.launch(share=True, server_name="0.0.0.0") | |