Spaces:
Running
Running
import os | |
import asyncio | |
import logging | |
import threading | |
import queue | |
import gradio as gr | |
import httpx | |
import time | |
import tempfile | |
from typing import Generator, Any, Dict, List, Optional | |
# -------------------- Configuration --------------------
# Configure the root logger once for the whole app: INFO and above, with timestamps.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# -------------------- External Model Call (with Retry) --------------------
async def call_model(prompt: str, model: str = "gpt-4o", api_key: Optional[str] = None, max_retries: int = 3) -> str:
    """Send ``prompt`` as a single user message to the OpenAI Chat Completions API.

    Transient failures are retried, up to ``max_retries`` attempts in total, with a
    sleep of 1s, 2s, 4s, ... between attempts. Transient means transport errors
    (``httpx.RequestError``) or HTTP 429/500/502/503/504. Any other HTTP status is
    raised immediately.

    Args:
        prompt: Text sent as the user message.
        model: Chat model name.
        api_key: OpenAI API key. When None or empty, ``OPENAI_API_KEY`` from the
            environment is used instead.
        max_retries: Maximum number of request attempts.

    Returns:
        The ``content`` of the first choice's message.

    Raises:
        ValueError: If no API key is available.
        httpx.HTTPStatusError: For a non-retryable HTTP status.
        Exception: If every attempt failed with a retryable error. The last
            underlying error is attached as ``__cause__``.
    """
    if not api_key:
        api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        raise ValueError("OpenAI API key not provided.")

    url = "https://api.openai.com/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
    }
    # Rate limiting (429) and server-side failures are worth retrying; client errors are not.
    retryable_statuses = {429, 500, 502, 503, 504}
    last_error: Optional[Exception] = None

    # A single client for all attempts, so the connection pool is reused between retries.
    async with httpx.AsyncClient(timeout=httpx.Timeout(300.0)) as client:
        for attempt in range(max_retries):
            try:
                response = await client.post(url, headers=headers, json=payload)
                response.raise_for_status()
                return response.json()["choices"][0]["message"]["content"]
            except httpx.HTTPStatusError as e:
                logging.error(f"HTTP error (attempt {attempt + 1}/{max_retries}): {e}")
                if e.response.status_code not in retryable_statuses:
                    raise
                last_error = e
            except httpx.RequestError as e:
                logging.error(f"Request error (attempt {attempt + 1}/{max_retries}): {e}")
                last_error = e
            except Exception as e:
                logging.error(f"Unexpected error (attempt {attempt+1}/{max_retries}): {e}")
                raise
            # Back off before the next attempt, but don't sleep after the final one.
            if attempt < max_retries - 1:
                await asyncio.sleep(2 ** attempt)

    raise Exception(f"Failed to get response after {max_retries} attempts.") from last_error
# -------------------- Conversation History Conversion -------------------- | |
def convert_history(history: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """
    Map internal history entries ({'agent', 'message'}) to Gradio chat messages
    ({'role', 'content'}).

    Entries whose agent is "user" (case-insensitive) become user messages carrying the
    raw text. Every other agent becomes an assistant message whose content is prefixed
    with "<agent>: ".
    """
    def _as_chat_message(item: Dict[str, str]) -> Dict[str, str]:
        speaker = item["agent"]
        if speaker.lower() != "user":
            return {"role": "assistant", "content": f'{speaker}: {item["message"]}'}
        return {"role": "user", "content": item["message"]}

    return [_as_chat_message(item) for item in history]
def conversation_to_text(history: List[Dict[str, str]]) -> str:
    """
    Render the conversation history as plain text.

    Each entry becomes one "<agent>: <message>" record. Records are joined with
    newlines, with no trailing newline; an empty history gives "".
    """
    return "\n".join(f"{item['agent']}: {item['message']}" for item in history)
# -------------------- Shared Context -------------------- | |
class Context:
    """Mutable state shared by all agents while a single task is being solved.

    ``original_task`` is required. ``review_comments`` starts as an empty list and
    every other artefact starts as ``None``; the agents fill them in as the pipeline
    progresses. ``conversation_history`` starts with one entry holding the user's task.
    """

    def __init__(self, original_task: str, optimized_task: Optional[str] = None,
                 plan: Optional[str] = None, code: Optional[str] = None,
                 review_comments: Optional[List[Dict[str, str]]] = None,
                 test_cases: Optional[str] = None, test_results: Optional[str] = None,
                 documentation: Optional[str] = None, conversation_history: Optional[List[Dict[str, str]]] = None):
        # The task as typed by the user, and the prompt optimizer's rewrite of it.
        self.original_task = original_task
        self.optimized_task = optimized_task
        # Artefacts produced along the pipeline.
        self.plan = plan
        self.code = code
        self.test_cases = test_cases
        self.test_results = test_results
        self.documentation = documentation
        # Falsy arguments (None or an empty list) are replaced by fresh defaults.
        self.review_comments = review_comments if review_comments else []
        self.conversation_history = (
            conversation_history if conversation_history
            else [{"agent": "User", "message": original_task}]
        )

    def add_conversation_entry(self, agent_name: str, message: str):
        """Append a {'agent': agent_name, 'message': message} record to the history."""
        entry = {"agent": agent_name, "message": message}
        self.conversation_history.append(entry)
# -------------------- Agent Classes -------------------- | |
class PromptOptimizerAgent:
    """Rewrites the user's task into a clearer prompt for the downstream agents."""

    async def optimize_prompt(self, context: Context, api_key: str) -> Context:
        """Ask the model to rewrite ``context.original_task`` and store the result.

        The rewrite goes into ``context.optimized_task`` and is also logged to the
        conversation history. Returns the same context object.
        """
        instructions = (
            "Improve the prompt. Be clear, specific, and complete. "
            "Keep original intent. Return ONLY the revised prompt."
        )
        request = "\n\n".join([instructions, f"User's prompt:\n{context.original_task}"])
        rewritten = await call_model(request, model="gpt-4o", api_key=api_key)
        context.optimized_task = rewritten
        context.add_conversation_entry("Prompt Optimizer", f"Optimized Task:\n{rewritten}")
        return context
class OrchestratorAgent:
    """Creates the execution plan, pausing for human input when the planner asks for it.

    If the planner's reply contains ``FEEDBACK_MARKER`` followed by a question, the
    question is published on ``log_queue`` in two forms: a readable
    ``("[Orchestrator]", ...)`` status line and a ``("human_feedback_request", details)``
    message. ``human_event`` is set while the agent waits for the answer on
    ``human_input_queue``.

    If no answer arrives within ``feedback_timeout`` seconds, the agent proceeds
    without one; ``None`` waits indefinitely. At most ``max_feedback_rounds`` questions
    are asked for one plan, so a planner that keeps asking cannot stall the pipeline
    forever.
    """

    FEEDBACK_MARKER = "REQUEST_HUMAN_FEEDBACK"

    def __init__(self, log_queue: queue.Queue, human_event: threading.Event, human_input_queue: queue.Queue,
                 feedback_timeout: Optional[float] = 120.0, max_feedback_rounds: int = 3):
        self.log_queue = log_queue
        self.human_event = human_event
        # Carries human answers only; questions go out through log_queue.
        self.human_input_queue = human_input_queue
        self.feedback_timeout = feedback_timeout
        self.max_feedback_rounds = max_feedback_rounds

    async def generate_plan(self, context: Context, api_key: str) -> Context:
        """Create or revise ``context.plan`` and return the same context.

        Every planner reply is added to the conversation history and pushed to
        ``log_queue`` as an ``("update", history)`` snapshot.

        A reply without the feedback marker becomes the plan. A reply with the marker
        triggers a question to the human; the answer is appended to the working plan,
        and the planner is asked to revise it.
        """
        last_round = max(self.max_feedback_rounds, 0)
        for round_index in range(last_round + 1):
            reply = await call_model(self._build_prompt(context), model="gpt-4o", api_key=api_key)
            context.add_conversation_entry("Orchestrator", f"Plan:\n{reply}")
            self.log_queue.put(("update", context.conversation_history))

            question = self._extract_question(reply)
            if question is None:
                context.plan = reply
                break
            if round_index == last_round:
                # Out of feedback rounds. Guarantee a non-empty plan; otherwise the
                # dispatcher would route straight back to the orchestrator.
                self.log_queue.put(("[Orchestrator]", "Feedback limit reached; continuing with the current plan."))
                if not context.plan:
                    context.plan = reply
                break

            answer = await self._ask_human(context, question)
            note = f"Human feedback on '{question}':\n{answer}"
            context.plan = f"{context.plan}\n{note}" if context.plan else note
        return context

    def _build_prompt(self, context: Context) -> str:
        """Return the planner prompt: revise the working plan if one exists, else create one."""
        ask_human = f"If unsure, output {self.FEEDBACK_MARKER} on its own line followed by your question."
        if context.plan:
            task = context.optimized_task or context.original_task
            return (
                f"You are a planner. Revise/complete the plan for '{task}'. "
                "Incorporate any human feedback it contains.\n\n"
                f"Current plan:\n{context.plan}\n\n"
                f"{ask_human}"
            )
        return (
            f"You are a planner. Create a plan for: '{context.optimized_task}'. "
            "Break down the task and assign sub-tasks to: Coder, Code Reviewer, Quality Assurance Tester, and Documentation Agent. "
            "Include review/revision steps, error handling, and documentation instructions.\n\n"
            f"{ask_human}"
        )

    def _extract_question(self, reply: str) -> Optional[str]:
        """Return the question that follows the feedback marker, or None if there is no marker.

        The marker may be followed by a real newline, a literal backslash-n, a colon,
        or nothing at all.
        """
        position = reply.find(self.FEEDBACK_MARKER)
        if position == -1:
            return None
        rest = reply[position + len(self.FEEDBACK_MARKER):].lstrip()
        if rest.startswith("\\n"):
            rest = rest[2:]
        return rest.lstrip(":").strip() or "(no question given)"

    async def _ask_human(self, context: Context, question: str) -> str:
        """Publish ``question`` and wait for an answer on ``human_input_queue``.

        Returns the answer as text. If the wait times out, returns a note telling the
        planner to proceed on its own assumptions.
        """
        self.log_queue.put(("[Orchestrator]", f"Requesting human feedback... Question: {question}"))
        details = f"Task: {context.optimized_task}\nCurrent Plan: {context.plan or 'None'}\nQuestion: {question}"
        # Never put the question on human_input_queue: a get() on that same queue
        # would immediately read the question back as if it were the human's answer.
        self.log_queue.put(("human_feedback_request", details))
        self.human_event.set()
        loop = asyncio.get_running_loop()
        try:
            # Block in a worker thread so the event loop stays responsive while waiting.
            answer = await loop.run_in_executor(
                None, lambda: self.human_input_queue.get(timeout=self.feedback_timeout)
            )
        except queue.Empty:
            answer = None
        finally:
            self.human_event.clear()
        if answer is None:
            self.log_queue.put(("[Orchestrator]", "No human feedback received; continuing without it."))
            return "No answer was received. Proceed with reasonable assumptions."
        self.log_queue.put(("[Orchestrator]", f"Received human feedback: {answer}"))
        return str(answer)
class CoderAgent:
    """Writes the code for the plan, and rewrites it in response to review or test feedback."""

    async def generate_code(self, context: Context, api_key: str, model: str = "gpt-4o") -> Context:
        """Generate code for ``context.plan``, store it in ``context.code``, and return the context.

        On a revision (code already exists), the prompt also carries:
        - the previous code;
        - the most recent code review, unless it contains "APPROVE";
        - the latest test results, unless they contain "TESTS PASSED".
        This lets the model fix what the reviewer and tester reported instead of
        redrafting from the plan alone.
        """
        sections = [
            "You are a coding agent. Output ONLY the code. "
            "Adhere to best practices and include error handling.",
            f"Instructions:\n{context.plan}",
        ]
        if context.code:
            sections.append(f"Previous version of the code:\n{context.code}")
            latest_review = next(
                (entry["message"] for entry in reversed(context.conversation_history)
                 if entry["agent"].lower() == "code reviewer"),
                None,
            )
            # Same approval test the reviewer uses ("APPROVE" anywhere, case-insensitive).
            if latest_review and "APPROVE" not in latest_review.upper():
                sections.append(f"Address this code review feedback:\n{latest_review}")
            if context.test_results and "TESTS PASSED" not in context.test_results.upper():
                sections.append(f"Fix the failures reported in these test results:\n{context.test_results}")
        # With no prior code this yields exactly the first-draft prompt.
        prompt = "\n\n".join(sections)
        code = await call_model(prompt, model=model, api_key=api_key)
        context.code = code
        context.add_conversation_entry("Coder", f"Code:\n{code}")
        return context
class CodeReviewerAgent:
    """Reviews the current code and records non-approving feedback on the context."""

    async def review_code(self, context: Context, api_key: str) -> Context:
        """Ask the model to review ``context.code`` against the optimized task.

        The raw review is always appended to the conversation history. If the review
        does not contain "APPROVE" (case-insensitive), every non-blank line is also
        stored in ``context.review_comments`` as one comment, with placeholder line
        number and severity. Returns the same context.
        """
        request = (
            "You are a code reviewer. Provide CONCISE feedback focusing on correctness, efficiency, readability, error handling, and security. "
            "If the code is acceptable, respond with ONLY 'APPROVE'. Do NOT generate code.\n\n"
            f"Task: {context.optimized_task}\n\nCode:\n{context.code}"
        )
        verdict = await call_model(request, model="gpt-4o", api_key=api_key)
        context.add_conversation_entry("Code Reviewer", f"Review:\n{verdict}")
        if "APPROVE" in verdict.upper():
            return context
        comments = [
            {"issue": text.strip(), "line_number": "N/A", "severity": "Medium"}
            for text in verdict.splitlines()
            if text.strip()
        ]
        context.review_comments.append({"comments": comments})
        return context
class QualityAssuranceTesterAgent:
    """Writes test cases for the current code and asks the model to assess them."""

    async def generate_test_cases(self, context: Context, api_key: str) -> Context:
        """Ask the model for test cases covering ``context.code``.

        The cases are stored in ``context.test_cases`` and logged to the conversation
        history. Returns the same context.
        """
        header = (
            "You are a testing agent. Generate comprehensive test cases considering edge cases and error scenarios. "
            "Output in a clear format.\n\n"
        )
        body = f"Task: {context.optimized_task}\n\nCode:\n{context.code}"
        generated = await call_model(header + body, model="gpt-4o", api_key=api_key)
        context.test_cases = generated
        context.add_conversation_entry("QA Tester", f"Test Cases:\n{generated}")
        return context

    async def run_tests(self, context: Context, api_key: str) -> Context:
        """Ask the model to compare expected and actual outcomes for ``context.test_cases``.

        The model's report is stored in ``context.test_results``. The prompt asks it to
        say "TESTS PASSED" on success. Returns the same context.

        NOTE(review): nothing is executed here. The "results" are the model's own
        assessment of the code, not the output of a real test run.
        """
        header = (
            "Run the test cases. Compare actual vs expected outputs and state any discrepancies. "
            "If all tests pass, output 'TESTS PASSED'.\n\n"
        )
        body = f"Code:\n{context.code}\n\nTest Cases:\n{context.test_cases}"
        report = await call_model(header + body, model="gpt-4o", api_key=api_key)
        context.test_results = report
        context.add_conversation_entry("QA Tester", f"Test Results:\n{report}")
        return context
class DocumentationAgent:
    """Produces user-facing documentation for the final code."""

    async def generate_documentation(self, context: Context, api_key: str) -> Context:
        """Ask the model to document ``context.code``.

        The result goes into ``context.documentation`` and is logged to the
        conversation history. Returns the same context.
        """
        instructions = "Generate clear documentation including a brief description, explanation, and a --help message."
        docs = await call_model(f"{instructions}\n\nCode:\n{context.code}", model="gpt-4o", api_key=api_key)
        context.documentation = docs
        context.add_conversation_entry("Documentation Agent", f"Documentation:\n{docs}")
        return context
# -------------------- Agent Dispatcher -------------------- | |
class AgentDispatcher:
    """Owns one instance of every agent and routes work to them by name."""

    def __init__(self, log_queue: queue.Queue, human_event: threading.Event, human_input_queue: queue.Queue):
        self.log_queue = log_queue
        self.human_event = human_event
        self.human_input_queue = human_input_queue
        # Only the orchestrator needs the queues/event (for human-in-the-loop feedback).
        self.agents = {
            "prompt_optimizer": PromptOptimizerAgent(),
            "orchestrator": OrchestratorAgent(log_queue, human_event, human_input_queue),
            "coder": CoderAgent(),
            "code_reviewer": CodeReviewerAgent(),
            "qa_tester": QualityAssuranceTesterAgent(),
            "documentation_agent": DocumentationAgent(),
        }

    async def dispatch(self, agent_name: str, context: Context, api_key: str, **kwargs) -> Context:
        """Run the agent called ``agent_name`` on ``context`` and return the updated context.

        A "[Agent Name]" start message is logged first. After the agent finishes, an
        ("update", history) snapshot is pushed to ``log_queue``.

        ``kwargs`` are forwarded only to the coder. For "qa_tester",
        ``generate_tests=True`` writes the test cases; otherwise ``run_tests=True``
        evaluates them; with neither flag the tester does nothing.

        Raises:
            ValueError: If ``agent_name`` is not a known agent. The start message has
                already been logged at that point.
        """
        label = agent_name.replace('_', ' ').title()
        self.log_queue.put((f"[{label}]", "Starting task..."))
        # Agents whose single entry point takes only (context, api_key).
        plain_entry_points = {
            "prompt_optimizer": "optimize_prompt",
            "orchestrator": "generate_plan",
            "code_reviewer": "review_code",
            "documentation_agent": "generate_documentation",
        }
        if agent_name in plain_entry_points:
            handler = getattr(self.agents[agent_name], plain_entry_points[agent_name])
            context = await handler(context, api_key)
        elif agent_name == "coder":
            context = await self.agents["coder"].generate_code(context, api_key, **kwargs)
        elif agent_name == "qa_tester":
            tester = self.agents["qa_tester"]
            if kwargs.get("generate_tests", False):
                context = await tester.generate_test_cases(context, api_key)
            elif kwargs.get("run_tests", False):
                context = await tester.run_tests(context, api_key)
        else:
            raise ValueError(f"Unknown agent: {agent_name}")
        self.log_queue.put(("update", context.conversation_history))
        return context

    async def determine_next_agent(self, context: Context, api_key: str) -> str:
        """Return the first pipeline stage whose output is still missing, or "done".

        Stages, in order:
        1. prompt_optimizer
        2. orchestrator
        3. coder
        4. code_reviewer, until some reviewer message contains "APPROVE"
        5. qa_tester, until test cases exist and the results contain "TESTS PASSED"
        6. documentation_agent

        ``api_key`` is accepted but not used.
        """
        def reviewer_has_approved() -> bool:
            return any(
                "APPROVE" in entry["message"].upper()
                for entry in context.conversation_history
                if entry["agent"].lower() == "code reviewer"
            )

        def tests_pending() -> bool:
            results = context.test_results
            return not context.test_cases or not results or "TESTS PASSED" not in results.upper()

        # Checks are evaluated lazily and in order; the first pending stage wins.
        pipeline = (
            ("prompt_optimizer", lambda: not context.optimized_task),
            ("orchestrator", lambda: not context.plan),
            ("coder", lambda: not context.code),
            ("code_reviewer", lambda: not reviewer_has_approved()),
            ("qa_tester", tests_pending),
            ("documentation_agent", lambda: not context.documentation),
        )
        for stage, is_pending in pipeline:
            if is_pending():
                return stage
        return "done"
# -------------------- Multi-Agent Conversation -------------------- | |
async def multi_agent_conversation(task_message: str, log_queue: queue.Queue, api_key: str,
                                   human_event: threading.Event, human_input_queue: queue.Queue,
                                   max_coder_iterations: int = 5) -> None:
    """Drive the agent pipeline for ``task_message`` until it finishes or gives up.

    Routing: a rejecting code review, or a test run whose results lack "TESTS PASSED",
    sends the work back to the coder. Otherwise the dispatcher picks the next missing
    stage. The run stops when the coder would be invoked again after more than
    ``max_coder_iterations`` revisions.

    Progress is reported on ``log_queue``: the dispatcher pushes ("update", history)
    snapshots, and this function always ends with a ("result", history) message. An
    error is recorded in the history as a "System" entry instead of being lost in the
    worker thread.
    """
    context = Context(original_task=task_message)
    dispatcher = AgentDispatcher(log_queue, human_event, human_input_queue)
    coder_iterations = 0
    try:
        next_agent = await dispatcher.determine_next_agent(context, api_key)
        while next_agent != "done":
            if next_agent == "qa_tester":
                if not context.test_cases:
                    context = await dispatcher.dispatch(next_agent, context, api_key, generate_tests=True)
                else:
                    context = await dispatcher.dispatch(next_agent, context, api_key, run_tests=True)
            elif next_agent == "coder" and (context.review_comments or context.test_results):
                # A revision (feedback exists), not the first draft.
                coder_iterations += 1
                # NOTE(review): confirm this model is still served by the API.
                context = await dispatcher.dispatch(next_agent, context, api_key, model="gpt-3.5-turbo-16k")
            else:
                context = await dispatcher.dispatch(next_agent, context, api_key)

            if next_agent == "code_reviewer":
                approved = any("APPROVE" in entry["message"].upper()
                               for entry in context.conversation_history
                               if entry["agent"].lower() == "code reviewer")
                next_agent = await dispatcher.determine_next_agent(context, api_key) if approved else "coder"
            elif (next_agent == "qa_tester" and context.test_results
                  and "TESTS PASSED" not in context.test_results.upper()):
                # Failing tests go back to the coder; re-running them on unchanged
                # code would loop forever.
                next_agent = "coder"
            else:
                next_agent = await dispatcher.determine_next_agent(context, api_key)

            if next_agent == "coder" and coder_iterations > max_coder_iterations:
                notice = "Maximum revision iterations reached. Exiting."
                log_queue.put(("[System]", notice))
                context.add_conversation_entry("System", notice)
                break
    except Exception as exc:
        # Top-level boundary of the worker thread: log it and show it to the user.
        logging.exception("Multi-agent conversation failed")
        context.add_conversation_entry("System", f"Error: {exc}")
    finally:
        log_queue.put(("result", context.conversation_history))
# -------------------- Process Conversation Generator -------------------- | |
def process_conversation_generator(task_message: str, api_key: str,
                                   human_event: threading.Event, human_input_queue: queue.Queue,
                                   log_queue: queue.Queue) -> Generator[Any, None, None]:
    """
    Run the multi-agent conversation in a background thread and stream its progress.

    Yields ``(chat_update, log_update)`` pairs of ``gr.update`` objects:
    - one pair for every ("update" | "result", history) tuple taken from
      ``log_queue``. The chat receives the history in Gradio messages format; the log
      state receives its plain-text rendering.
    - a final pair once the worker thread has exited and the queue is drained.
    Other queue messages are ignored.
    """
    last_log_text = ""

    def run_conversation():
        asyncio.run(multi_agent_conversation(task_message, log_queue, api_key, human_event, human_input_queue))

    conversation_thread = threading.Thread(target=run_conversation)
    conversation_thread.start()
    while conversation_thread.is_alive() or not log_queue.empty():
        try:
            # The get() timeout alone paces this loop. No extra sleep, so queued
            # messages reach the UI as soon as they arrive.
            msg = log_queue.get(timeout=0.1)
        except queue.Empty:
            continue
        if isinstance(msg, tuple) and msg[0] in ("update", "result"):
            chat_update = gr.update(value=convert_history(msg[1]), visible=True)
            last_log_text = conversation_to_text(msg[1])
            yield (chat_update, gr.update(value=last_log_text))
    conversation_thread.join()
    yield (gr.update(visible=True), gr.update(value=last_log_text))
# -------------------- Multi-Agent Chat Function -------------------- | |
def multi_agent_chat(message: str, openai_api_key: str = None) -> Generator[Any, None, None]:
    """
    Gradio handler for the Send button: run the agents on ``message`` and stream
    ``(chat_update, log_update)`` pairs.

    A missing or empty key falls back to the OPENAI_API_KEY environment variable. If
    that is also unset, one error message is shown in the chat and nothing runs.
    """
    api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
    if not api_key:
        error_chat = [{"role": "assistant", "content": "Error: API key not provided."}]
        yield (gr.update(value=error_chat), gr.update())
        return
    # Fresh coordination primitives per run: human event, human input queue, log queue.
    yield from process_conversation_generator(
        message,
        api_key,
        threading.Event(),
        queue.Queue(),
        queue.Queue(),
    )
# -------------------- Download Log Function -------------------- | |
def download_log(log_text: str) -> str:
    """
    Save ``log_text`` to a new UTF-8 ``.txt`` file in the system temp directory and
    return the file's path.

    The file is deliberately not deleted here, so the returned path remains valid for
    the caller to serve.
    """
    log_file = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", suffix=".txt", delete=False)
    try:
        log_file.write(log_text)
    finally:
        log_file.close()
    return log_file.name
# -------------------- Custom Gradio Blocks Interface -------------------- | |
# Extra CSS passed to gr.Blocks.
# NOTE(review): no component below sets elem_id or elem_classes, so the id/class
# selectors here (#gen_btn, #title, #gallery, #lora_list, .card_internal, #progress,
# .progress-*, .chat-container) match nothing and look copied from another app.
# `.styler` and `.generating` may target Gradio-internal classes; confirm before pruning.
css = '''
#gen_btn{height: 100%}
#gen_column{align-self: stretch}
#title{text-align: center}
#title h1{font-size: 3em; display:inline-flex; align-items:center}
#title img{width: 100px; margin-right: 0.5em}
#gallery .grid-wrap{height: 10vh}
#lora_list{background: var(--block-background-fill);padding: 0 1em .3em; font-size: 90%}
.card_internal{display: flex;height: 100px;margin-top: .5em}
.card_internal img{margin-right: 1em}.styler{--form-gap-width: 0px !important}
#progress{height:30px}#progress .generating{display:none}.progress-container {width: 100%;height: 30px;background-color: #f0f0f0;border-radius: 15px;overflow: hidden;margin-bottom: 20px}.progress-bar {height: 100%;background-color: #4f46e5;width: calc(var(--current) / var(--total) * 100%);transition: width 0.5s ease-in-out}
/* Add this to make the chatbot bigger */
.chat-container {
    height: 600px; /* Adjust as needed */
    overflow-y: scroll; /* Add scrollbar if content overflows */
}
'''
with gr.Blocks(theme="CultriX/gradio-theme", css=css, delete_cache=(60, 60)) as demo:
    gr.Markdown("## Multi-Agent Task Solver with Human-in-the-Loop")
    with gr.Row():
        with gr.Column():
            # Enlarge the chat through the component's own height setting. Assigning an
            # attribute and rendering an empty gr.HTML block does not resize it.
            chat_output = gr.Chatbot(label="Conversation", type="messages", height=600)
    # Hidden plain-text copy of the conversation, consumed by the "Download Log" button.
    log_state = gr.State(value="")
    with gr.Row():
        with gr.Column(scale=8):
            message_input = gr.Textbox(label="Enter your task", placeholder="Type your task here...", lines=3)
        with gr.Column(scale=2):
            api_key_input = gr.Textbox(label="API Key (optional)", type="password", placeholder="Leave blank to use env variable")
    send_button = gr.Button("Send")
    # multi_agent_chat streams (chat update, plain-text log update) pairs.
    send_button.click(fn=multi_agent_chat, inputs=[message_input, api_key_input], outputs=[chat_output, log_state])
    with gr.Row():
        download_button = gr.Button("Download Log")
        download_file = gr.File(label="Download your log file")
    download_button.click(fn=download_log, inputs=log_state, outputs=download_file)

if __name__ == "__main__":
    demo.launch(share=True)