# test_api/flask_app.py
# Author: API-Handler — commit 501c69f ("Upload 10 files", verified)
from flask import Flask, request, Response, jsonify, stream_with_context
from flask_cors import CORS
import json
from typegpt_api import generate, model_mapping, simplified_models
from api_info import developer_info, model_providers
app = Flask(__name__)

# Allow cross-origin requests from any origin to every route.
# Flask-CORS spells its options `allow_headers` / `supports_credentials`; the
# keys previously used here (`headers`, `allow_credentials`) are not recognised
# by Flask-CORS, which only logs "Unknown option" and ignores them.
CORS(app, resources={
    r"/*": {
        "origins": "*",
        "methods": ["*"],
        "allow_headers": ["*"],
        # NOTE(review): credentials stay disabled (this was already the
        # effective behaviour). Combining credentials with a wildcard origin
        # would let any website issue credentialed requests; enable it only
        # together with an explicit list of trusted origins.
        "supports_credentials": False,
    }
})
@app.route("/health_check", methods=['GET'])
def health_check():
    """Liveness probe: always answers ``{"status": "OK"}`` with HTTP 200."""
    status = {"status": "OK"}
    return jsonify(status)
@app.route("/models", methods=['GET'])
def get_models():
    """List every model from ``model_providers`` as an OpenAI-style list.

    Each entry carries the model id, ``"object": "model"``, the provider key
    it was found under, and that provider's description. Any failure while
    building the list is reported as ``{"error": <message>}`` with HTTP 500.
    """
    try:
        models = [
            {
                "id": model_id,
                "object": "model",
                "provider": provider_name,
                "description": provider_info["description"],
            }
            for provider_name, provider_info in model_providers.items()
            for model_id in provider_info["models"]
        ]
        return jsonify({"object": "list", "data": models})
    except Exception as err:
        return jsonify({"error": str(err)}), 500
@app.route("/chat/completions", methods=['POST'])
def chat_completions():
    """Chat-completion endpoint that forwards requests to ``generate``.

    Reads a JSON object from the request body, requires ``model`` and
    ``messages``, and passes the sampling parameters through to
    ``typegpt_api.generate``. When ``stream`` is truthy the chunks are sent
    as Server-Sent Events (``data: <json>`` per chunk, then
    ``data: [DONE]``); otherwise the result of ``generate`` is returned as
    JSON.

    Error responses:
        400 ``{"error": ...}`` if the body is not a JSON object or a required
        field is missing.
        500 ``{"error": ...}`` if ``generate`` raises before the response
        starts.
        If iterating an in-flight stream raises, the HTTP status has already
        been sent, so the failure is emitted as a final
        ``data: {"error": ...}`` event instead.
    """
    # silent=True yields None (instead of raising) for malformed JSON or a
    # non-JSON content type; the isinstance check additionally rejects valid
    # JSON that is not an object (e.g. a list), which would break body.get().
    body = request.get_json(silent=True)
    if not isinstance(body, dict):
        return jsonify({"error": "Invalid JSON payload"}), 400

    model = body.get("model")
    messages = body.get("messages")
    if not model:
        return jsonify({"error": "The 'model' parameter is required."}), 400
    if not messages:
        return jsonify({"error": "The 'messages' parameter is required."}), 400

    # Arguments shared by the streaming and non-streaming generate() calls.
    params = {
        "model": model,
        "messages": messages,
        "temperature": body.get("temperature", 0.7),
        "top_p": body.get("top_p", 1.0),
        "n": body.get("n", 1),
        "stop": body.get("stop"),
        "max_tokens": body.get("max_tokens"),
        "presence_penalty": body.get("presence_penalty", 0.0),
        "frequency_penalty": body.get("frequency_penalty", 0.0),
        "logit_bias": body.get("logit_bias"),
        "user": body.get("user"),
        # Fixed timeout handed to generate(); units are defined by
        # typegpt_api (presumably seconds — TODO confirm).
        "timeout": 30,
    }
    stream = body.get("stream", False)

    try:
        if not stream:
            return jsonify(generate(stream=False, **params))
        # Called here rather than inside the body generator so that a failure
        # raised by the call itself still turns into a proper 500 response.
        chunks = generate(stream=True, **params)
    except Exception as e:
        return jsonify({"error": str(e)}), 500

    def generate_stream():
        """Yield SSE frames for each chunk, then the [DONE] sentinel."""
        try:
            for chunk in chunks:
                yield f"data: {json.dumps(chunk)}\n\n"
        except Exception as e:
            # The 200 status line is already on the wire, so the outer
            # handler cannot help; report the failure in-band and stop.
            error_payload = json.dumps({"error": str(e)})
            yield f"data: {error_payload}\n\n"
            return
        yield "data: [DONE]\n\n"

    # "Connection" and "Transfer-Encoding" are hop-by-hop headers that a WSGI
    # application must not set (PEP 3333); the server picks chunked encoding
    # itself for a body without Content-Length.
    return Response(
        stream_with_context(generate_stream()),
        mimetype="text/event-stream",
        headers={"Cache-Control": "no-cache"},
    )
@app.route("/developer_info", methods=['GET'])
def get_developer_info():
    """Return the ``developer_info`` object imported from ``api_info`` as JSON."""
    info = developer_info
    return jsonify(info)
if __name__ == "__main__":
    # Run Flask's built-in server, listening on all interfaces at port 8000.
    app.run(port=8000, host="0.0.0.0")