File size: 4,432 Bytes
501c69f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from fastapi import FastAPI, Request, Response
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import json

from typegpt_api import generate, model_mapping, simplified_models
from api_info import developer_info, model_providers

# Application instance; routes below attach to it via decorators.
app = FastAPI()

# Set up CORS middleware if needed
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers under the CORS spec (credentialed requests require an
# explicit, non-wildcard origin) — confirm whether credentialed cross-origin
# calls are actually needed here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/health_check")
async def health_check():
    """Liveness probe: always reports the service as up."""
    payload = {"status": "OK"}
    return payload

@app.get("/models")
async def get_models():
    """List every model of every provider in an OpenAI-style envelope.

    Each entry carries the model id, its provider name, and the
    provider's description.  Any failure is reported as a 500 with the
    error text in the body.
    """
    try:
        catalogue = [
            {
                "id": model_id,
                "object": "model",
                "provider": provider_name,
                "description": provider_info["description"],
            }
            for provider_name, provider_info in model_providers.items()
            for model_id in provider_info["models"]
        ]
        return JSONResponse(content={"object": "list", "data": catalogue})
    except Exception as exc:
        return JSONResponse(content={"error": str(exc)}, status_code=500)

@app.post("/chat/completions")
async def chat_completions(request: Request):
    """OpenAI-compatible chat-completions endpoint.

    Parses the JSON body, validates the required 'model' and 'messages'
    fields, and forwards the sampling parameters to `generate`.  When
    'stream' is truthy the reply is a server-sent-event stream terminated
    by a "[DONE]" sentinel; otherwise a single JSON document is returned.
    Returns 400 for malformed/incomplete requests, 500 for upstream errors.
    """
    # Reject malformed bodies early with a 400 rather than a late 500.
    try:
        body = await request.json()
    except Exception:
        return JSONResponse(content={"error": "Invalid JSON payload"}, status_code=400)

    # Required parameters.
    model = body.get("model")
    messages = body.get("messages")
    if not model:
        return JSONResponse(content={"error": "The 'model' parameter is required."}, status_code=400)
    if not messages:
        return JSONResponse(content={"error": "The 'messages' parameter is required."}, status_code=400)

    # Optional sampling parameters; defaults mirror the OpenAI API.
    # Collected once so the streaming and non-streaming branches cannot
    # drift apart (previously the 13 kwargs were duplicated in both).
    gen_kwargs = dict(
        model=model,
        messages=messages,
        temperature=body.get("temperature", 0.7),
        top_p=body.get("top_p", 1.0),
        n=body.get("n", 1),
        stop=body.get("stop"),
        max_tokens=body.get("max_tokens"),
        presence_penalty=body.get("presence_penalty", 0.0),
        frequency_penalty=body.get("frequency_penalty", 0.0),
        logit_bias=body.get("logit_bias"),
        user=body.get("user"),
        timeout=30,  # fixed server-side timeout for the upstream call
    )

    try:
        if body.get("stream", False):
            async def generate_stream():
                # NOTE(review): `generate` looks like a synchronous iterator;
                # iterating it here blocks the event loop while each chunk is
                # produced — consider offloading via run_in_executor. TODO confirm.
                #
                # An exception raised after the first chunk can no longer
                # become an HTTP 500 (headers are already sent), so report
                # it in-band instead of abruptly dropping the connection,
                # which is what the previous uncaught iteration did.
                try:
                    for chunk in generate(stream=True, **gen_kwargs):
                        yield f"data: {json.dumps(chunk)}\n\n"
                except Exception as exc:
                    yield f"data: {json.dumps({'error': str(exc)})}\n\n"
                yield "data: [DONE]\n\n"

            return StreamingResponse(
                generate_stream(),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "Transfer-Encoding": "chunked"
                }
            )
        response = generate(stream=False, **gen_kwargs)
        return JSONResponse(content=response)
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)

@app.get("/developer_info")
async def get_developer_info():
    """Expose the static developer metadata imported from api_info."""
    return JSONResponse(developer_info)

if __name__ == "__main__":
    # Development entry point: serve on all interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)