|
import gevent.pywsgi
|
|
from gevent import monkey;monkey.patch_all()
|
|
from flask import Flask, request, Response
|
|
import argparse
|
|
import requests
|
|
import random
|
|
import string
|
|
import time
|
|
import json
|
|
import os
|
|
|
|
app = Flask(__name__)

# Command-line options for the standalone server started in the __main__ block.
parser = argparse.ArgumentParser(description="An example of Qwen demo with a similar API to OAI.")
parser.add_argument("--host", type=str, help="Set the ip address.(default: 0.0.0.0)", default='0.0.0.0')
parser.add_argument("--port", type=int, help="Set the port.(default: 7860)", default=7860)
# NOTE: parsed at import time; args.host / args.port are read when run as a script.
args = parser.parse_args()

# Root URL of the upstream Gradio app (None when MODEL_BASE_URL is unset).
# Trailing slashes are stripped because request URLs are built as
# f"{base_url}/queue/...", which would otherwise contain "//" and miss the
# upstream routes.
base_url = os.getenv('MODEL_BASE_URL')
if base_url:
    base_url = base_url.rstrip("/")
|
|
|
|
@app.route("/", methods=["GET"])
def index():
    """Serve a small HTML landing page.

    The page tells users which URL to configure as the API proxy/domain in
    their chatbot client. It is built from the ``SPACE_URL`` environment
    variable.
    """
    space_url = os.getenv("SPACE_URL")
    parts = [
        'QW1_5 OpenAI Compatible API<br><br>',
        f'Set "{space_url}/api" as proxy (or API Domain) in your Chatbot.<br><br>',
        f'The complete API is: {space_url}/api/v1/chat/completions',
    ]
    return Response("".join(parts))
|
|
|
|
def _cors_headers():
    """Return the permissive CORS headers sent with every chat-completions response."""
    return {
        "Access-Control-Allow-Origin": "*",
        "Access-Control-Allow-Headers": "*",
    }


def _error_response(message, status):
    """Return an OpenAI-style JSON error body with the given HTTP status and CORS headers."""
    return Response(
        json.dumps({"error": {"message": message}}),
        status=status,
        mimetype="application/json",
        headers=_cors_headers(),
    )


def _split_messages(messages):
    """Convert OpenAI ``messages`` into the ``(system, chat_history, prompt)`` triple.

    The content of the last message is the prompt. Every earlier message is
    examined together with its successor:

    * a ``system`` message replaces the default system prompt (the last one wins);
    * a ``user`` message becomes a ``[user_content, reply]`` pair, where
      ``reply`` is the successor's content if the successor is an ``assistant``
      message and ``" "`` otherwise.

    Messages with any other role are only consumed as replies.

    Args:
        messages: Non-empty list of dicts with ``role`` / ``content`` keys.

    Returns:
        ``(system, chat_history, prompt)``.
    """
    system = "You are a helpful assistant."
    chat_history = []
    # zip(messages, messages[1:]) yields each message with its successor,
    # so the final message (the prompt) is never treated as history.
    for current, following in zip(messages, messages[1:]):
        role = current.get("role")
        if role == "system":
            system = current.get("content")
        elif role == "user":
            if following.get("role") == "assistant":
                reply = following.get("content")
            else:
                reply = " "
            chat_history.append([current.get("content"), reply])
    return system, chat_history, messages[-1].get("content")


@app.route("/api/v1/chat/completions", methods=["POST", "OPTIONS"])
@app.route("/v1/chat/completions", methods=["POST", "OPTIONS"])
def chat_completions():
    """Proxy an OpenAI-style chat completion request to the upstream Gradio queue.

    ``OPTIONS`` answers the CORS preflight with an empty body.

    ``POST`` expects a JSON object with a non-empty ``messages`` list. It
    submits the conversation to ``{base_url}/queue/join``, then relays that
    session's event stream as ``chat.completion.chunk`` server-sent events
    terminated by ``data: [DONE]``. The reply is always streamed.

    Returns:
        * 400 JSON error if the body is not a JSON object or ``messages`` is
          missing, empty, or contains non-objects.
        * 502 JSON error if the upstream queue cannot be joined.
        * Otherwise a ``text/event-stream`` response.
    """
    if request.method == "OPTIONS":
        return Response(headers=_cors_headers())

    # silent=True: a missing or invalid JSON body yields None instead of
    # raising, so the client gets a JSON error with CORS headers.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return _error_response("Request body must be a JSON object.", 400)
    messages = data.get("messages")
    if (
        not isinstance(messages, list)
        or not messages
        or not all(isinstance(message, dict) for message in messages)
    ):
        return _error_response("'messages' must be a non-empty list of objects.", 400)

    system, chat_history, prompt = _split_messages(messages)

    # Fresh session id per request; the queue events of this job are fetched
    # with it from /queue/data below.
    session_hash = "".join(random.choices(string.ascii_lowercase + string.digits, k=11))

    json_prompt = {
        # NOTE(review): positional inputs of upstream fn_index 0 are assumed to
        # be [message, history, system] — confirm against the upstream app.
        "data": [prompt, chat_history, system],
        "fn_index": 0,
        "session_hash": session_hash,
    }

    # Join the queue before starting the stream so an upstream failure can be
    # reported with a proper status code instead of an aborted 200 stream.
    try:
        joined = requests.post(f"{base_url}/queue/join", json=json_prompt)
        joined.raise_for_status()
    except requests.RequestException:
        app.logger.exception("Joining the upstream queue failed")
        return _error_response("Failed to submit the request to the model queue.", 502)

    stream_url = f"{base_url}/queue/data?session_hash={session_hash}"

    def generate():
        """Translate the upstream queue's SSE events into OpenAI chunk events."""
        # `with` closes the upstream connection however iteration ends,
        # including early termination (WSGI servers call close() on the iterable).
        with requests.get(stream_url, stream=True) as upstream:
            upstream.raise_for_status()
            created = int(time.time())
            for raw_line in upstream.iter_lines():
                # Only SSE "data:" lines carry events; skip keep-alive blank
                # lines and any other SSE fields.
                if not raw_line.startswith(b"data:"):
                    continue
                # json.loads tolerates the optional space after "data:".
                event = json.loads(raw_line[len(b"data:"):].decode("utf-8"))
                msg = event.get("msg")
                if msg == "process_starts":
                    chunk = gen_res_data({}, time_now=created, start=True)
                elif msg == "process_generating":
                    chunk = gen_res_data(event, time_now=created)
                elif msg == "process_completed":
                    # An SSE event is dispatched only after its terminating
                    # blank line; the job is finished, so stop reading.
                    yield "data: [DONE]\n\n"
                    return
                else:
                    continue
                yield f"data: {json.dumps(chunk)}\n\n"

    return Response(
        generate(),
        mimetype="text/event-stream",
        headers=_cors_headers(),
    )
|
|
|
|
|
|
def gen_res_data(data, time_now=0, start=False):
    """Build one OpenAI ``chat.completion.chunk`` dict from a Gradio queue event.

    Args:
        data: A parsed queue event. Only ``data["output"]["data"][1]`` is read,
            and it is ignored entirely when ``start`` is true.
        time_now: Unix timestamp placed in the ``created`` field.
        start: When true, emit the opening chunk announcing the assistant role.

    Returns:
        The chunk dict with a single choice:

        * opening chunk: ``delta = {"role": "assistant", "content": ""}``;
        * empty ``data["output"]["data"][1]``: ``finish_reason = "stop"`` and
          an empty ``delta``;
        * otherwise: ``delta = {"content": <last element of the last entry>}``.
    """
    choice = {"index": 0, "finish_reason": None}
    res_data = {
        "id": "chatcmpl",
        "object": "chat.completion.chunk",
        "created": time_now,
        "model": "qwen1_5",
        "choices": [choice],
    }

    if start:
        choice["delta"] = {"role": "assistant", "content": ""}
        return res_data

    # NOTE(review): the meaning of output.data[1] depends on the upstream
    # Gradio output format — the text taken from its last entry is assumed
    # to be the new delta; confirm against the upstream protocol version.
    chat_pair = data["output"]["data"][1]
    if not chat_pair:
        # OpenAI chunk choices always carry a `delta`; clients reading
        # choices[0].delta would fail on the terminating chunk without it.
        choice["finish_reason"] = "stop"
        choice["delta"] = {}
    else:
        choice["delta"] = {"content": chat_pair[-1][-1]}
    return res_data
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the Flask app on gevent's WSGI server at the address given on the
    # command line; serve_forever() blocks until the server is stopped.
    listen_address = (args.host, args.port)
    server = gevent.pywsgi.WSGIServer(listen_address, app)
    server.serve_forever()
|
|
|