from flask import Flask, request, Response, jsonify, stream_with_context
from flask_cors import CORS
import json
from typegpt_api import generate, model_mapping, simplified_models
from api_info import developer_info, model_providers

app = Flask(__name__)
# Allow cross-origin requests from any origin
CORS(app, resources={
    r"/*": {
        "origins": "*",
        "supports_credentials": True,  # flask-cors option name (not "allow_credentials")
        "methods": ["*"],
        "allow_headers": ["*"]  # flask-cors option name (not "headers")
    }
})
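# Note: browsers ignore Access-Control-Allow-Credentials when the allowed
# origin is the wildcard "*"; pin explicit origins if credentialed requests matter.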

@app.route("/")  # assumed path; the route decorator was missing from this copy
def health_check():
    return jsonify({"status": "OK"})

@app.route("/v1/models", methods=["GET"])  # assumed path (OpenAI-style convention)
def get_models():
    try:
        response = {
            "object": "list",
            "data": []
        }
        # Flatten every provider's model list into one OpenAI-style listing
        for provider, info in model_providers.items():
            for model in info["models"]:
                response["data"].append({
                    "id": model,
                    "object": "model",
                    "provider": provider,
                    "description": info["description"]
                })
        return jsonify(response)
    except Exception as e:
        return jsonify({"error": str(e)}), 500

@app.route("/v1/chat/completions", methods=["POST"])  # assumed path (OpenAI-style convention)
def chat_completions():
    # Parse the JSON payload; get_json raises on malformed bodies and
    # returns None when no JSON body is present
    try:
        body = request.get_json()
    except Exception:
        return jsonify({"error": "Invalid JSON payload"}), 400
    if body is None:
        return jsonify({"error": "Invalid JSON payload"}), 400

    # Extract OpenAI-style parameters, falling back to sensible defaults
    model = body.get("model")
    messages = body.get("messages")
    temperature = body.get("temperature", 0.7)
    top_p = body.get("top_p", 1.0)
    n = body.get("n", 1)
    stream = body.get("stream", False)
    stop = body.get("stop")
    max_tokens = body.get("max_tokens")
    presence_penalty = body.get("presence_penalty", 0.0)
    frequency_penalty = body.get("frequency_penalty", 0.0)
    logit_bias = body.get("logit_bias")
    user = body.get("user")
    timeout = 30  # seconds; or set based on your preference

    # Validate required parameters
    if not model:
        return jsonify({"error": "The 'model' parameter is required."}), 400
    if not messages:
        return jsonify({"error": "The 'messages' parameter is required."}), 400
    # Call the generate function
    try:
        if stream:
            def generate_stream():
                response = generate(
                    model=model,
                    messages=messages,
                    temperature=temperature,
                    top_p=top_p,
                    n=n,
                    stream=True,
                    stop=stop,
                    max_tokens=max_tokens,
                    presence_penalty=presence_penalty,
                    frequency_penalty=frequency_penalty,
                    logit_bias=logit_bias,
                    user=user,
                    timeout=timeout,
                )
                # Re-emit each chunk in Server-Sent Events framing:
                # a "data: <json>" line followed by a blank line
                for chunk in response:
                    yield f"data: {json.dumps(chunk)}\n\n"
                yield "data: [DONE]\n\n"

            return Response(
                stream_with_context(generate_stream()),
                mimetype="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive"
                    # Transfer-Encoding is hop-by-hop and must not be set
                    # by the app; the WSGI server handles chunking
                }
            )
        else:
            response = generate(
                model=model,
                messages=messages,
                temperature=temperature,
                top_p=top_p,
                n=n,
                stream=False,
                stop=stop,
                max_tokens=max_tokens,
                presence_penalty=presence_penalty,
                frequency_penalty=frequency_penalty,
                logit_bias=logit_bias,
                user=user,
                timeout=timeout,
            )
            return jsonify(response)
    except Exception as e:
        return jsonify({"error": str(e)}), 500

@app.route("/developer_info", methods=["GET"])  # assumed path
def get_developer_info():
    return jsonify(developer_info)

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=8000)
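# Example request against a local run (payload values are illustrative):
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "gpt-3.5-turbo",
#          "messages": [{"role": "user", "content": "Hello"}],
#          "stream": false}'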