File size: 3,237 Bytes
fe5c39d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time    : 2024/3/27 9:44
@Author  : leiwu30
@File    : stream_output_via_api.py
@Description    : Stream log information and communicate over the network via web api.
"""
import asyncio
import json
import socket
import threading
from contextvars import ContextVar

from flask import Flask, Response, jsonify, request, send_from_directory

from metagpt.const import TUTORIAL_PATH
from metagpt.logs import logger, set_llm_stream_logfunc
from metagpt.roles.tutorial_assistant import TutorialAssistant
from metagpt.utils.stream_pipe import StreamPipe

app = Flask(__name__)

# Holds the StreamPipe for the request currently being served, so that
# stream_pipe_log (registered as MetaGPT's LLM stream log function) can
# forward chunks to the right client. Defined at module level so the name
# resolves even when this module is imported (e.g. by a WSGI server) rather
# than executed as a script; the __main__ block re-creates it at startup,
# which is harmless before any request has been handled.
stream_pipe_var: ContextVar[StreamPipe] = ContextVar("stream_pipe")


def stream_pipe_log(content):
    """Echo a streamed LLM chunk to stdout and forward it to the active request's pipe.

    Registered via set_llm_stream_logfunc; if no StreamPipe is bound in the
    current context (no request in flight), the chunk is only printed.
    """
    print(content, end="")
    pipe = stream_pipe_var.get(None)
    if pipe:
        pipe.set_message(content)


def write_tutorial(message):
    """Run the TutorialAssistant in a background thread and yield its streamed output.

    Args:
        message: dict with a "content" key holding the tutorial topic.

    Yields:
        Chunks formatted by StreamPipe.msg2stream, suitable for a streaming
        HTTP response body.
    """

    async def _run_agent(idea, pipe):
        # Bind the pipe into this context so stream_pipe_log can reach it.
        stream_pipe_var.set(pipe)
        await TutorialAssistant().run(idea)

    def _worker(idea: str, pipe: StreamPipe = None):
        # Bridge the async agent into a plain thread via its own event loop.
        asyncio.run(_run_agent(idea, pipe))

    pipe = StreamPipe()
    worker = threading.Thread(target=_worker, args=(message["content"], pipe))
    worker.start()

    # Relay messages to the caller until the agent thread finishes.
    # NOTE(review): a chunk set after the last get_message() but before the
    # thread dies may be dropped — depends on StreamPipe semantics; confirm.
    while worker.is_alive():
        yield pipe.msg2stream(pipe.get_message())


@app.route("/v1/chat/completions", methods=["POST"])
def completions():
    """OpenAI-style chat-completions endpoint backed by MetaGPT agents.

    Expected JSON body:
        {
            "model": "write_tutorial",      # selects the MetaGPT agent
            "stream": true,                 # only streaming is supported
            "messages": [
                {"role": "user", "content": "Write a tutorial about MySQL"}
            ]
        }

    Returns:
        A streaming text/plain Response produced by the selected agent, or a
        JSON error payload. Errors use HTTP 200 with "status": 400 in the
        body — kept as-is for backward compatibility with existing clients.
    """
    data = json.loads(request.data)
    logger.info(json.dumps(data, indent=4, ensure_ascii=False))

    # Non-streaming interfaces are not supported yet.
    if not data.get("stream"):
        return jsonify({"status": 400, "msg": "Non-streaming requests are not supported, please use `stream=True`."})

    # Only the last message is used; openai['model'] maps to a MetaGPT agent.
    last_message = data["messages"][-1]
    model = data["model"]

    if model == "write_tutorial":
        return Response(write_tutorial(last_message), mimetype="text/plain")
    return jsonify({"status": 400, "msg": "No suitable agent found."})


@app.route("/download/<path:filename>")
def download_file(filename):
    # Serve a generated tutorial file from TUTORIAL_PATH as an attachment
    # download; <path:...> allows slashes so nested paths under TUTORIAL_PATH
    # resolve (send_from_directory rejects traversal outside the directory).
    return send_from_directory(TUTORIAL_PATH, filename, as_attachment=True)


if __name__ == "__main__":
    """
    curl https://$server_address:$server_port/v1/chat/completions -X POST -d '{
        "model": "write_tutorial",
        "stream": true,
        "messages": [
          {
               "role": "user",
               "content": "Write a tutorial about MySQL"
          }
        ]
    }'
    """
    # Bind to this host's primary resolved IP on port 7860.
    server_port = 7860
    server_address = socket.gethostbyname(socket.gethostname())

    # Route MetaGPT's LLM stream output through stream_pipe_log above.
    set_llm_stream_logfunc(stream_pipe_log)
    # ContextVar that stream_pipe_log reads to find the active request's pipe.
    # NOTE(review): created only when run as a script — importing this module
    # (e.g. under gunicorn) would leave this name undefined; verify deployment.
    stream_pipe_var: ContextVar[StreamPipe] = ContextVar("stream_pipe")
    app.run(port=server_port, host=server_address)