# FastAPI server exposing OpenAI-compatible chat-completion endpoints.
from fastapi import FastAPI, Request, Response
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import json
from typegpt_api import generate, model_mapping, simplified_models
from api_info import developer_info, model_providers
app = FastAPI()

# Permit cross-origin requests from any origin so browser-based clients
# can call this API directly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/health_check")
async def health_check():
    """Liveness probe: always reports the service as up."""
    status_payload = {"status": "OK"}
    return status_payload
@app.get("/models")
async def get_models():
    """List every model offered by each configured provider.

    Returns an OpenAI-style ``{"object": "list", "data": [...]}`` payload
    where each entry carries the model id, its provider, and the
    provider's description.
    """
    try:
        catalog = [
            {
                "id": model_id,
                "object": "model",
                "provider": provider_name,
                "description": provider_info["description"],
            }
            for provider_name, provider_info in model_providers.items()
            for model_id in provider_info["models"]
        ]
        return JSONResponse(content={"object": "list", "data": catalog})
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)
@app.post("/chat/completions")
async def chat_completions(request: Request):
    """OpenAI-compatible chat-completions endpoint.

    Parses the JSON request body, validates the required ``model`` and
    ``messages`` fields, and forwards everything to ``generate`` —
    either as a Server-Sent-Events stream (``stream=True``) or as a
    single JSON response.

    Returns:
        400 on invalid JSON or missing required fields,
        500 if ``generate`` raises,
        otherwise a StreamingResponse or JSONResponse with the result.
    """
    # Receive the JSON payload.
    try:
        body = await request.json()
    except Exception:
        return JSONResponse(content={"error": "Invalid JSON payload"}, status_code=400)

    # Build the keyword arguments for `generate` ONCE so the streaming and
    # non-streaming branches cannot drift apart (previously the full list
    # was duplicated in both branches).
    model = body.get("model")
    messages = body.get("messages")
    gen_kwargs = {
        "model": model,
        "messages": messages,
        "temperature": body.get("temperature", 0.7),
        "top_p": body.get("top_p", 1.0),
        "n": body.get("n", 1),
        "stop": body.get("stop"),
        "max_tokens": body.get("max_tokens"),
        "presence_penalty": body.get("presence_penalty", 0.0),
        "frequency_penalty": body.get("frequency_penalty", 0.0),
        "logit_bias": body.get("logit_bias"),
        "user": body.get("user"),
        "timeout": 30,  # fixed server-side timeout for upstream calls
    }
    stream = body.get("stream", False)

    # Validate required parameters.
    if not model:
        return JSONResponse(content={"error": "The 'model' parameter is required."}, status_code=400)
    if not messages:
        return JSONResponse(content={"error": "The 'messages' parameter is required."}, status_code=400)

    # Call the generate function.
    try:
        if stream:
            async def generate_stream():
                # NOTE(review): `generate` looks synchronous; iterating it here
                # blocks the event loop between chunks — confirm upstream.
                response = generate(stream=True, **gen_kwargs)
                for chunk in response:
                    yield f"data: {json.dumps(chunk)}\n\n"
                yield "data: [DONE]\n\n"

            return StreamingResponse(
                generate_stream(),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "Transfer-Encoding": "chunked",
                },
            )
        response = generate(stream=False, **gen_kwargs)
        return JSONResponse(content=response)
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)
@app.get("/developer_info")
async def get_developer_info():
    """Expose the static developer metadata imported from api_info."""
    info_payload = developer_info
    return JSONResponse(content=info_payload)
if __name__ == "__main__":
    # Development entry point: bind to all interfaces on port 8000.
    # (Removed a stray trailing "|" artifact that made this line a syntax error.)
    uvicorn.run(app, host="0.0.0.0", port=8000)