|
|
|
|
|
from __future__ import annotations |
|
|
|
import asyncio |
|
import contextlib |
|
import pathlib |
|
import shutil |
|
import traceback |
|
import uuid |
|
from collections import deque |
|
from datetime import datetime |
|
from enum import Enum |
|
from functools import partial |
|
from typing import Any, Optional |
|
|
|
import fire |
|
import uvicorn |
|
from fastapi import FastAPI, Request |
|
from fastapi.responses import StreamingResponse |
|
from fastapi.staticfiles import StaticFiles |
|
from loguru import logger |
|
from metagpt.actions.action import Action |
|
from metagpt.actions.action_output import ActionOutput |
|
from metagpt.config import CONFIG |
|
from metagpt.logs import set_llm_stream_logfunc |
|
from metagpt.schema import Message |
|
from pydantic import BaseModel, Field |
|
|
|
from software_company import RoleRun, SoftwareCompany |
|
|
|
|
|
class QueryAnswerType(Enum):
    """Marks a message as a query ("Q") or an answer ("A")."""

    Query = "Q"
    Answer = "A"
|
|
|
|
|
class SentenceType(Enum):
    """Kind of content carried by a ``Sentence``; the string ``.value`` is what gets serialized."""

    TEXT = "text"
    # Misspelled original member name, kept as the canonical member for backward compatibility.
    HIHT = "hint"
    # Correctly spelled alias: same value, so ``SentenceType.HINT is SentenceType.HIHT``.
    # Enum aliases are skipped during iteration, so ``list(SentenceType)`` is unchanged.
    HINT = "hint"
    ACTION = "action"
    ERROR = "error"
|
|
|
|
|
class MessageStatus(Enum):
    """Status values for a ``Sentences`` step; only ``COMPLETE`` is defined."""

    COMPLETE = "complete"
|
|
|
|
|
class SentenceValue(BaseModel):
    # Payload of a ``Sentence``: the text shown to the client (query echo,
    # streamed LLM output, action result or formatted traceback).
    answer: str
|
|
|
|
|
class Sentence(BaseModel):
    # One content item of a step. ``type`` is populated from ``SentenceType``
    # values in this module; ``id`` and ``is_finished`` are optional.
    type: str
    id: Optional[str] = None
    value: SentenceValue
    # False while the item is still streaming, True once final.
    is_finished: Optional[bool] = None
|
|
|
|
|
class Sentences(BaseModel):
    # One step of an aggregated answer: step metadata plus its content items,
    # stored as plain dicts. Field order is kept as-is because it fixes the
    # key order of the serialized JSON.
    id: Optional[str] = None
    action: Optional[str] = None
    role: Optional[str] = None
    skill: Optional[str] = None
    description: Optional[str] = None
    # Local time; NOTE(review): ``%z`` renders as "" for the naive ``datetime.now()``.
    timestamp: str = Field(default_factory=lambda: datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%f%z"))
    status: str
    contents: list[dict]
|
|
|
|
|
class NewMsg(BaseModel):
    """Chat with MetaGPT"""

    # Request body of the chat endpoint. The docstring and the Field descriptions
    # are surfaced in the generated API schema, so their text is kept verbatim.
    query: str = Field(description="Problem description")
    # Per-request overrides; keys are upper-cased before being applied to CONFIG.
    config: dict[str, Any] = Field(description="Configuration information")
|
|
|
|
|
class ErrorInfo(BaseModel):
    # Error message plus formatted traceback. Both default to None, so they are
    # annotated Optional; a bare ``str = None`` declares a type the default violates.
    error: Optional[str] = None
    traceback: Optional[str] = None
|
|
|
|
|
class ThinkActStep(BaseModel):
    # One think/act step as streamed to the client. Field order fixes the key
    # order of the serialized JSON and is kept unchanged.
    id: str
    # Values used in this module: "running", "finish", "failed".
    status: str
    title: str
    timestamp: str = Field(default_factory=lambda: datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%f%z"))
    description: str
    # Absent until the step produces output, hence Optional (the default is None).
    content: Optional[Sentence] = None
|
|
|
|
|
class ThinkActPrompt(BaseModel):
    # Streaming payload for a single think/act step; sent to the client via ``prompt``.
    # Fields defaulting to None are annotated Optional to match their defaults.

    message_id: Optional[int] = None
    timestamp: str = Field(default_factory=lambda: datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%f%z"))
    step: Optional[ThinkActStep] = None
    skill: Optional[str] = None
    role: Optional[str] = None

    def update_think(self, tc_id, action: Action):
        """Start a new ``running`` step whose title and description are ``action.desc``.

        Args:
            tc_id: Step counter; stored (as a string) as the step ``id``.
            action: Any object exposing a ``desc`` attribute.
        """
        self.step = ThinkActStep(
            id=str(tc_id),
            status="running",
            title=action.desc,
            description=action.desc,
        )

    def update_act(self, message: ActionOutput | str, is_finished: bool = True):
        """Replace the current step's content with ``message``.

        ``update_think`` must have been called first so that ``self.step`` exists.

        Args:
            message: Either a plain string (e.g. a streamed LLM chunk) or an object
                with a ``content`` attribute (e.g. an ``ActionOutput``).
            is_finished: When true, the step's status becomes ``"finish"``.
        """
        if is_finished:
            self.step.status = "finish"
        # Strings are used as-is; anything else contributes its ``content``. This no
        # longer depends on ``is_finished``, so a finished plain string does not fail.
        answer = message if isinstance(message, str) else message.content
        self.step.content = Sentence(
            type=SentenceType.TEXT.value,
            id="1",
            value=SentenceValue(answer=answer),
            is_finished=is_finished,
        )

    @staticmethod
    def guid32():
        """Return a random 32-character hex string (a UUID4 with its dashes removed)."""
        return str(uuid.uuid4()).replace("-", "")[0:32]

    @property
    def prompt(self):
        """JSON for this prompt, omitting fields that were never explicitly set.

        NOTE(review): ``exclude_unset`` also omits fields that only received their
        default (e.g. ``timestamp``); confirm the front-end does not rely on them.
        """
        return self.json(exclude_unset=True)
|
|
|
|
|
class MessageJsonModel(BaseModel):
    # Aggregated answer for one request. ``steps`` grows by one ``Sentences``
    # entry per call to ``add_think_act``; the whole model is sent via ``prompt``.

    steps: list[Sentences]
    qa_type: str
    created_at: datetime = Field(default_factory=datetime.now)
    query_time: datetime = Field(default_factory=datetime.now)
    answer_time: datetime = Field(default_factory=datetime.now)
    score: Optional[int] = None
    feedback: Optional[str] = None

    def add_think_act(self, think_act_prompt: ThinkActPrompt):
        """Append the current step of ``think_act_prompt`` to ``steps``.

        The step's ``content`` is stored as a dict. A step whose ``content`` is
        still None (``ThinkActStep``'s default) gets an empty ``contents`` list
        instead of raising AttributeError.
        """
        step = think_act_prompt.step
        contents = [] if step.content is None else [step.content.dict()]
        self.steps.append(
            Sentences(
                action=step.title,
                skill=think_act_prompt.skill,
                description=step.description,
                timestamp=think_act_prompt.timestamp,
                status=step.status,
                contents=contents,
            )
        )

    @property
    def prompt(self):
        """JSON for this answer, omitting fields that were never explicitly set."""
        return self.json(exclude_unset=True)
|
|
|
|
|
async def create_message(req_model: NewMsg, request: Request):
    """
    Session message stream

    Async generator that runs a ``SoftwareCompany`` role over ``req_model.query``
    and yields JSON chunks, each followed by a blank line:

    * a ``ThinkActPrompt`` when a think/act step starts, for every streamed LLM
      chunk received while it runs, and when it finishes;
    * the accumulated ``MessageJsonModel`` once ``role.think()`` returns a falsy
      result;
    * a ``ThinkActPrompt`` whose step has ``status="failed"`` if an exception escapes.

    Args:
        req_model: The query plus per-request config overrides. Keys listed in
            ``SERVER_METAGPT_CONFIG_EXCLUDE`` are dropped; the rest are upper-cased.
        request: The incoming request, polled to detect client disconnects.

    Generation stops, and any in-flight ``role.act()`` is cancelled, when the
    client disconnects. The per-request workspace directory is removed when the
    generator finishes for any reason.
    """
    tc_id = 0
    task = None  # in-flight ``role.act()`` task of the current step
    watcher = None  # background task cancelling ``task`` on client disconnect
    workspace_path = None  # workspace configured by this request; removed in ``finally``
    try:
        exclude_keys = CONFIG.get("SERVER_METAGPT_CONFIG_EXCLUDE", [])
        config = {k.upper(): v for k, v in req_model.config.items() if k not in exclude_keys}
        workspace_path = set_context(config, uuid.uuid4().hex)["WORKSPACE_PATH"]

        # The stream callback pushes chunks on the left and the loop below pops
        # them from the right, so chunks are forwarded in arrival order.
        msg_queue = deque()
        CONFIG.LLM_STREAM_LOG = lambda x: msg_queue.appendleft(x) if x else None

        role = SoftwareCompany()
        role.recv(message=Message(content=req_model.query))
        answer = MessageJsonModel(
            steps=[
                Sentences(
                    contents=[
                        Sentence(
                            type=SentenceType.TEXT.value, value=SentenceValue(answer=req_model.query), is_finished=True
                        )
                    ],
                    status=MessageStatus.COMPLETE.value,
                )
            ],
            qa_type=QueryAnswerType.Answer.value,
        )

        async def stop_if_disconnect():
            """Cancel the running act task as soon as the client disconnects."""
            while not await request.is_disconnected():
                await asyncio.sleep(1)
            # ``task`` is a closure variable, so this sees whichever step is current.
            if task is not None and not task.done():
                task.cancel()
                logger.info(f"cancel task {task}")

        # Keep a strong reference: the event loop only holds weak references to
        # tasks, and the watcher must be cancelled once the stream ends.
        watcher = asyncio.create_task(stop_if_disconnect())

        while True:
            tc_id += 1
            if await request.is_disconnected():
                return
            think_result: RoleRun = await role.think()
            if not think_result:  # nothing left to do
                break

            think_act_prompt = ThinkActPrompt(role=think_result.role.profile)
            think_act_prompt.update_think(tc_id, think_result)
            yield think_act_prompt.prompt + "\n\n"
            task = asyncio.create_task(role.act())

            # Forward streamed LLM output until the act task completes.
            while not await request.is_disconnected():
                if msg_queue:
                    think_act_prompt.update_act(msg_queue.pop(), False)
                    yield think_act_prompt.prompt + "\n\n"
                    continue

                if task.done():
                    break

                await asyncio.sleep(0.5)
            else:  # client disconnected while the step was running
                task.cancel()
                return

            act_result = await task
            think_act_prompt.update_act(act_result)
            yield think_act_prompt.prompt + "\n\n"
            answer.add_think_act(think_act_prompt)

        # The final summary tells the front-end that the message is complete.
        yield answer.prompt + "\n\n"
    except asyncio.CancelledError:
        # Let cancellation propagate; pending tasks are cancelled in ``finally``
        # (``task`` may still be None here, so it must not be dereferenced).
        raise
    except Exception as ex:
        logger.exception("create_message failed")
        description = str(ex)
        error_trace = traceback.format_exc()
        step = ThinkActStep(
            id=str(tc_id),
            status="failed",
            title=description,
            description=description,
            content=Sentence(
                type=SentenceType.ERROR.value, id="1", value=SentenceValue(answer=error_trace), is_finished=True
            ),
        )
        think_act_prompt = ThinkActPrompt(step=step)
        yield think_act_prompt.prompt + "\n\n"
    finally:
        # Runs on completion, error, cancellation and generator close alike, so
        # neither the act task nor the watcher outlives the response.
        for pending in (task, watcher):
            if pending is not None and not pending.done():
                pending.cancel()
        # Only remove the workspace this request configured; never fall back to
        # whatever CONFIG.WORKSPACE_PATH holds if set_context did not run.
        if workspace_path is not None and pathlib.Path(workspace_path).exists():
            try:
                shutil.rmtree(workspace_path)
            except OSError:
                logger.warning(f"failed to remove workspace {workspace_path}")
|
|
|
|
|
# Fallback stream logger: writes each chunk to stdout without appending a newline.
default_llm_stream_log = partial(print, end="")


def llm_stream_log(msg):
    """Forward one chunk of streamed LLM output to the active stream logger.

    Calls the ``LLM_STREAM_LOG`` callable from the current config (``create_message``
    installs one per request), falling back to ``default_llm_stream_log``.

    Logging is best-effort. An exception from the callback is logged rather than
    raised to the caller (presumably MetaGPT's LLM streaming code, which receives
    this function via ``set_llm_stream_logfunc``).
    """
    try:
        CONFIG._get("LLM_STREAM_LOG", default_llm_stream_log)(msg)
    except Exception:
        # Replaces ``contextlib.suppress()``: called with no exception types it
        # suppressed nothing. This makes the best-effort intent explicit and visible.
        logger.exception("LLM stream log callback failed")
|
|
|
|
|
def set_context(context, uid):
    """Apply the overrides in ``context`` to ``CONFIG`` and return ``context``.

    ``context`` is modified in place before being applied:

    * ``WORKSPACE_PATH`` is set to ``pathlib.Path("workspace", uid)``;
    * for each legacy key that is present while its newer name is absent, the
      value is copied to the newer name (``DEPLOYMENT_ID`` -> ``DEPLOYMENT_NAME``,
      ``OPENAI_API_BASE`` -> ``OPENAI_BASE_URL``). The legacy key is kept.
    """
    renamed_keys = {
        "DEPLOYMENT_ID": "DEPLOYMENT_NAME",
        "OPENAI_API_BASE": "OPENAI_BASE_URL",
    }
    context["WORKSPACE_PATH"] = pathlib.Path("workspace", uid)
    for legacy_key, current_key in renamed_keys.items():
        if current_key not in context and legacy_key in context:
            context[current_key] = context[legacy_key]
    CONFIG.set_context(context)
    return context
|
|
|
|
|
class ChatHandler:
    """HTTP endpoint handlers for the chat API."""

    @staticmethod
    async def create_message(req_model: NewMsg, request: Request):
        """Message stream, using SSE."""
        # The docstring above is reused by FastAPI as the route description, so it
        # is kept verbatim. The bare name ``create_message`` below resolves to the
        # module-level async generator, not to this method.
        sse_headers = {"Cache-Control": "no-cache", "Connection": "keep-alive"}
        stream = create_message(req_model, request)
        return StreamingResponse(stream, media_type="text/event-stream", headers=sse_headers)
|
|
|
|
|
app = FastAPI()

# Files under ./storage/ are served at /storage.
app.mount("/storage", StaticFiles(directory="./storage/"), name="storage")

# Streaming chat endpoint (Server-Sent Events).
app.add_api_route(
    "/api/messages",
    endpoint=ChatHandler.create_message,
    methods=["post"],
    summary="Session message sending (streaming response)",
)

# Registered last on purpose. Starlette tries routes in registration order, and a
# mount at "/" matches every path, so it must come after the API route.
app.mount(
    "/",
    StaticFiles(directory="./static/", html=True, follow_symlink=True),
    name="static",
)

# Route MetaGPT's streamed LLM output through ``llm_stream_log``.
set_llm_stream_logfunc(llm_stream_log)
|
|
|
|
|
def main():
    """Serve ``app`` with uvicorn, passing the ``SERVER_UVICORN`` config mapping as extra options."""
    uvicorn_options = CONFIG.get("SERVER_UVICORN", {})
    # The app is given as an import string. NOTE(review): "__main__:app" only
    # resolves to this module's ``app`` when this file is run as a script.
    uvicorn.run(app="__main__:app", **uvicorn_options)


if __name__ == "__main__":
    fire.Fire(component=main)
|
|