from __future__ import annotations

import asyncio
import contextlib
import os
import pathlib
import re
import shutil
import time
import traceback
import uuid
from collections import deque
from contextlib import asynccontextmanager
from functools import partial
from typing import Dict

import fire
import tenacity
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from loguru import logger
from metagpt.config import CONFIG
from metagpt.logs import set_llm_stream_logfunc
from metagpt.schema import Message
from metagpt.utils.common import any_to_name, any_to_str
from openai import OpenAI

from data_model import (
    LLMAPIkeyTest,
    MessageJsonModel,
    NewMsg,
    Sentence,
    Sentences,
    SentenceType,
    SentenceValue,
    ThinkActPrompt,
    ThinkActStep,
)
from message_enum import MessageStatus, QueryAnswerType
from software_company import RoleRun, SoftwareCompany
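

# Streams a MetaGPT "software company" run to the browser as Server-Sent
# Events: each think/act step is serialized through ThinkActPrompt and
# flushed to the client as soon as it is produced.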
class Service:
    @classmethod
    async def create_message(cls, req_model: NewMsg, request: Request):
        """Session message stream."""
        tc_id = 0
        task = None
        try:
            exclude_keys = CONFIG.get("SERVER_METAGPT_CONFIG_EXCLUDE", [])
            config = {k.upper(): v for k, v in req_model.config.items() if k not in exclude_keys}
            cls._set_context(config)
            msg_queue = deque()
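            # The LLM streaming callback is the producer and the SSE loop
            # below is the consumer: tokens are appended on the left and
            # popped from the right, preserving arrival order.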
            CONFIG.LLM_STREAM_LOG = lambda x: msg_queue.appendleft(x) if x else None

            role = SoftwareCompany()
            role.recv(message=Message(content=req_model.query))
            answer = MessageJsonModel(
                steps=[
                    Sentences(
                        contents=[
                            Sentence(
                                type=SentenceType.TEXT.value,
                                value=SentenceValue(answer=req_model.query),
                                is_finished=True,
                            ).model_dump()
                        ],
                        status=MessageStatus.COMPLETE.value,
                    )
                ],
                qa_type=QueryAnswerType.Answer.value,
            )
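
            # Watchdog: poll the connection once a second; if the client goes
            # away, cancel whichever act() task is currently in flight.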
            async def stop_if_disconnect():
                while not await request.is_disconnected():
                    await asyncio.sleep(1)

                if task is None:
                    return

                if not task.done():
                    task.cancel()
                    logger.info(f"cancel task {task}")

            asyncio.create_task(stop_if_disconnect())
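
            # Main loop: one iteration per think/act round; it ends when the
            # role decides there is nothing left to do.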
            while True:
                tc_id += 1
                if await request.is_disconnected():
                    return
                think_result: RoleRun = await role.think()
                if not think_result:
                    break

                think_act_prompt = ThinkActPrompt(role=think_result.role.profile)
                think_act_prompt.update_think(tc_id, think_result)
                yield think_act_prompt.prompt + "\n\n"
                task = asyncio.create_task(role.act())
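
                # Drain streamed tokens while act() runs. The while/else fires
                # only when the loop exits without break, i.e. on disconnect.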
                while not await request.is_disconnected():
                    if msg_queue:
                        think_act_prompt.update_act(msg_queue.pop(), False)
                        yield think_act_prompt.prompt + "\n\n"
                        continue

                    if task.done():
                        break

                    await asyncio.sleep(0.5)
                else:
                    task.cancel()
                    return

                act_result = await task
                think_act_prompt.update_act(act_result)
                yield think_act_prompt.prompt + "\n\n"
                answer.add_think_act(think_act_prompt)
            yield answer.prompt + "\n\n"  # the final, aggregated answer
        except asyncio.CancelledError:
            if task is not None:  # cancellation can arrive before any act() task exists
                task.cancel()
        except tenacity.RetryError as retry_error:
            yield cls.handle_retry_error(tc_id, retry_error)
        except Exception as ex:
            description = str(ex)
            answer = traceback.format_exc()
            think_act_prompt = cls.create_error_think_act_prompt(tc_id, description, description, answer)
            yield think_act_prompt.prompt + "\n\n"
        finally:
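            # Each session gets its own workspace (see _set_context); always
            # remove it when the stream ends, whatever the outcome.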
            CONFIG.WORKSPACE_PATH: pathlib.Path  # bare annotation: a type hint only, no runtime effect
            if CONFIG.WORKSPACE_PATH.exists():
                shutil.rmtree(CONFIG.WORKSPACE_PATH)

    @staticmethod
    def create_error_think_act_prompt(tc_id: int, title: str, description: str, answer: str) -> ThinkActPrompt:
        step = ThinkActStep(
            id=tc_id,
            status="failed",
            title=title,
            description=description,
            content=Sentence(type=SentenceType.ERROR.value, id=1, value=SentenceValue(answer=answer), is_finished=True),
        )
        return ThinkActPrompt(step=step)

    @classmethod
    def handle_retry_error(cls, tc_id: int, error: tenacity.RetryError):
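        # tenacity may wrap one RetryError inside another when retried calls
        # themselves retry; walk down to the root cause before dispatching on
        # its module name.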
        try:
            original_exception = error.last_attempt.exception()
            while isinstance(original_exception, tenacity.RetryError):
                original_exception = original_exception.last_attempt.exception()

            name = any_to_str(original_exception)
            if re.match(r"^openai\.", name):
                return cls._handle_openai_error(tc_id, original_exception)

            if re.match(r"^httpx\.", name):
                return cls._handle_httpx_error(tc_id, original_exception)

            if re.match(r"^json\.", name):
                return cls._handle_json_error(tc_id, original_exception)

            return cls.handle_unexpected_error(tc_id, error)
        except Exception:
            return cls.handle_unexpected_error(tc_id, error)

    @classmethod
    def _handle_openai_error(cls, tc_id, original_exception):
        answer = original_exception.message
        title = f"OpenAI {any_to_name(original_exception)}"
        think_act_prompt = cls.create_error_think_act_prompt(tc_id, title, title, answer)
        return think_act_prompt.prompt + "\n\n"

    @classmethod
    def _handle_httpx_error(cls, tc_id, original_exception):
        answer = f"{original_exception}. {original_exception.request}"
        title = f"httpx {any_to_name(original_exception)}"
        think_act_prompt = cls.create_error_think_act_prompt(tc_id, title, title, answer)
        return think_act_prompt.prompt + "\n\n"

    @classmethod
    def _handle_json_error(cls, tc_id, original_exception):
        answer = str(original_exception)
        title = "MetaGPT Action Node Error"
        description = f"LLM response parse error. {any_to_str(original_exception)}: {original_exception}"
        think_act_prompt = cls.create_error_think_act_prompt(tc_id, title, description, answer)
        return think_act_prompt.prompt + "\n\n"

    @classmethod
    def handle_unexpected_error(cls, tc_id, error):
        description = str(error)
        answer = traceback.format_exc()
        think_act_prompt = cls.create_error_think_act_prompt(tc_id, description, description, answer)
        return think_act_prompt.prompt + "\n\n"

    @staticmethod
    def _set_context(context: Dict) -> Dict:
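        # Give every request its own workspace directory, keyed by a random
        # hex UUID, so concurrent sessions cannot overwrite each other's files.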
        uid = uuid.uuid4().hex
        context["WORKSPACE_PATH"] = pathlib.Path("workspace", uid)
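        # Map legacy config keys to their current names so older client
        # payloads keep working unchanged.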
        for old, new in (("DEPLOYMENT_ID", "DEPLOYMENT_NAME"), ("OPENAI_API_BASE", "OPENAI_BASE_URL")):
            if old in context and new not in context:
                context[new] = context[old]
        CONFIG.set_context(context)
        return context


default_llm_stream_log = partial(print, end="")
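

# Dispatch MetaGPT's token stream to the per-request queue installed in
# create_message, falling back to plain stdout when none is set.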
def llm_stream_log(msg):
    with contextlib.suppress(Exception):
        CONFIG._get("LLM_STREAM_LOG", default_llm_stream_log)(msg)


class ChatHandler:
    @staticmethod
    async def create_message(req_model: NewMsg, request: Request):
        """Message stream, using SSE."""
        event = Service.create_message(req_model, request)
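        # Mark the response uncacheable and keep the connection open so each
        # chunk reaches the client as soon as it is yielded.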
        headers = {"Cache-Control": "no-cache", "Connection": "keep-alive"}
        return StreamingResponse(event, headers=headers, media_type="text/event-stream")


class LLMAPIHandler:
    @staticmethod
    async def check_openai_key(req_model: LLMAPIkeyTest):
        try:
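            # Validate the key by listing the account's models and checking
            # that the requested model id is among them. The client call is
            # synchronous, so it briefly blocks the event loop.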
            client = OpenAI(api_key=req_model.api_key)
            response = client.models.list()
            model_set = {model.id for model in response.data}
            if req_model.llm_type in model_set:
                logger.info("API key is valid.")
                return JSONResponse({"valid": True})
            else:
                logger.info("API key accepted, but the requested model was not found.")
                return JSONResponse({"valid": False, "message": "Model not found"})
        except Exception as e:
            logger.info(f"Error: {e}")
            return JSONResponse({"valid": False, "message": str(e)})


@asynccontextmanager
async def lifespan(app: FastAPI):
    loop = asyncio.get_running_loop()
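    # Start the storage sweeper for the app's lifetime; stashing the task on
    # app.state keeps a reference so it cannot be garbage-collected mid-run.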
    app.state.clear_storage_task = loop.create_task(clear_storage())
    yield


app = FastAPI(lifespan=lifespan)

app.mount(
    "/storage",
    StaticFiles(directory="./storage/"),
    name="storage",
)

app.add_api_route(
    "/api/messages",
    endpoint=ChatHandler.create_message,
    methods=["POST"],
    summary="Session message sending (streaming response)",
)
app.add_api_route(
    "/api/test-api-key",
    endpoint=LLMAPIHandler.check_openai_key,
    methods=["POST"],
    summary="LLM API key check",
)
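
# The catch-all "/" static mount must come after the API routes: Starlette
# matches in registration order, so mounting it earlier would shadow them.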
app.mount(
    "/",
    StaticFiles(directory="./static/", html=True, follow_symlink=True),
    name="static",
)

set_llm_stream_logfunc(llm_stream_log)


def gen_file_modified_time(folder_path):
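    """Yield the mtime of the folder itself and of every file beneath it;
    the caller takes max() as the directory's last-activity time."""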
    yield os.path.getmtime(folder_path)
    for root, _, files in os.walk(folder_path):
        for file in files:
            file_path = os.path.join(root, file)
            yield os.path.getmtime(file_path)


async def clear_storage(ttl: float = 1800):
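    """Delete any session directory under the storage root that has seen no
    writes for `ttl` seconds (30 minutes by default), checking once a minute."""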
    storage = pathlib.Path(CONFIG.get("LOCAL_ROOT", "storage"))
    logger.info("task `clear_storage` started")

    while True:
        current_time = time.time()
        for i in os.listdir(storage):
            i = storage / i
            try:
                last_time = max(gen_file_modified_time(i))
                if current_time - last_time > ttl:
                    shutil.rmtree(i)
                    await asyncio.sleep(0)
                    logger.info(f"Deleted directory: {i}")
            except Exception:
                logger.exception(f"check {i} error")
        await asyncio.sleep(60)


def main():
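    # Keys under SERVER_UVICORN (e.g. host, port) are passed straight
    # through to uvicorn.run.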
    server_config = CONFIG.get("SERVER_UVICORN", {})
    uvicorn.run(app="__main__:app", **server_config)


if __name__ == "__main__":
    fire.Fire(main)