from flask import Flask, request, Response, jsonify, stream_with_context
from flask_cors import CORS
import json

from typegpt_api import generate, model_mapping, simplified_models
from api_info import developer_info, model_providers

app = Flask(__name__)

# Allow cross-origin requests from any origin.
# FIX: flask_cors expects the option keys "allow_headers" and
# "supports_credentials"; the original "headers" / "allow_credentials"
# keys are not recognized options and were silently ignored.
CORS(app, resources={
    r"/*": {
        "origins": "*",
        "methods": ["*"],
        "allow_headers": ["*"],
        "supports_credentials": True,
    }
})

# Upstream timeout (seconds) applied to every generate() call.
GENERATE_TIMEOUT = 30


@app.route("/health_check", methods=['GET'])
def health_check():
    """Liveness probe: always reports the service as healthy."""
    return jsonify({"status": "OK"})


@app.route("/models", methods=['GET'])
def get_models():
    """List every model of every configured provider, OpenAI-list style."""
    try:
        response = {"object": "list", "data": []}
        for provider, info in model_providers.items():
            for model in info["models"]:
                response["data"].append({
                    "id": model,
                    "object": "model",
                    "provider": provider,
                    "description": info["description"],
                })
        return jsonify(response)
    except Exception as e:
        # Boundary handler: surface any failure as a JSON 500.
        return jsonify({"error": str(e)}), 500


@app.route("/chat/completions", methods=['POST'])
def chat_completions():
    """OpenAI-compatible chat completion endpoint.

    Accepts the usual completion parameters in the JSON body and either
    streams the result as server-sent events (``stream: true``) or returns
    a single JSON response.
    """
    try:
        body = request.get_json()
    except Exception:
        return jsonify({"error": "Invalid JSON payload"}), 400
    if body is None:
        # Older Flask returns None (instead of raising) when the request
        # Content-Type is not JSON; without this check the .get() calls
        # below would crash with an uncaught 500.
        return jsonify({"error": "Invalid JSON payload"}), 400

    model = body.get("model")
    messages = body.get("messages")

    # Validate required parameters before doing anything else.
    if not model:
        return jsonify({"error": "The 'model' parameter is required."}), 400
    if not messages:
        return jsonify({"error": "The 'messages' parameter is required."}), 400

    stream = body.get("stream", False)

    # One kwargs dict shared by both branches -- the original duplicated
    # this entire argument list in the streaming and non-streaming paths.
    gen_kwargs = dict(
        model=model,
        messages=messages,
        temperature=body.get("temperature", 0.7),
        top_p=body.get("top_p", 1.0),
        n=body.get("n", 1),
        stop=body.get("stop"),
        max_tokens=body.get("max_tokens"),
        presence_penalty=body.get("presence_penalty", 0.0),
        frequency_penalty=body.get("frequency_penalty", 0.0),
        logit_bias=body.get("logit_bias"),
        user=body.get("user"),
        timeout=GENERATE_TIMEOUT,
    )

    try:
        if stream:
            def generate_stream():
                # Server-sent events: one "data:" frame per chunk,
                # terminated by the conventional [DONE] sentinel.
                for chunk in generate(stream=True, **gen_kwargs):
                    yield f"data: {json.dumps(chunk)}\n\n"
                yield "data: [DONE]\n\n"

            # FIX: "Transfer-Encoding" is a hop-by-hop header owned by the
            # WSGI server (PEP 3333); setting it by hand is rejected by
            # modern Werkzeug, so it is intentionally omitted here.
            return Response(
                stream_with_context(generate_stream()),
                mimetype="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                },
            )

        return jsonify(generate(stream=False, **gen_kwargs))
    except Exception as e:
        return jsonify({"error": str(e)}), 500


@app.route("/developer_info", methods=['GET'])
def get_developer_info():
    """Expose static developer/contact metadata."""
    return jsonify(developer_info)


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=8000)