|
from fastapi import FastAPI, HTTPException
|
|
from fastapi.responses import StreamingResponse
|
|
from pydantic import BaseModel, Field
|
|
from typing import List, Optional, Any, Dict
|
|
from deepinfra_handler import DeepInfraHandler
|
|
import json
|
|
|
|
# FastAPI application exposing an OpenAI-style chat-completions endpoint.
app = FastAPI()

# Single module-level handler shared by all requests.
# NOTE(review): assumed stateless/thread-safe across concurrent requests — confirm.
api_handler = DeepInfraHandler()
|
|
|
|
class Message(BaseModel):
    """A single chat message in the OpenAI chat-completions format."""

    # Speaker role — presumably "system", "user", or "assistant", but no
    # validation is applied here, so any string is accepted.
    role: str
    # Text content of the message.
    content: str
|
|
|
|
class ChatCompletionRequest(BaseModel):
    """Request body for ``POST /chat/completions`` (OpenAI-compatible).

    Sampling parameters are optional; validation bounds mirror the
    ranges accepted by the OpenAI API.
    """

    # Backend model identifier (forwarded verbatim to the handler).
    model: str
    # Ordered conversation history.
    messages: List[Message]
    temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0)
    max_tokens: Optional[int] = Field(default=4096, ge=1)
    top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
    frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
    presence_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
    # default_factory instead of a shared mutable [] literal — the
    # idiomatic way to declare mutable defaults; same empty-list default.
    stop: Optional[List[str]] = Field(default_factory=list)
    # When true the endpoint responds with a server-sent-event stream.
    stream: Optional[bool] = Field(default=False)
|
|
|
|
@app.post("/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """OpenAI-compatible chat-completions endpoint.

    When ``request.stream`` is true, returns a ``text/event-stream``
    response emitting OpenAI-style SSE delta chunks followed by a
    ``[DONE]`` sentinel; otherwise returns the handler's full response.

    Raises:
        HTTPException: 500 for unexpected backend failures; deliberate
            ``HTTPException``s from downstream code are re-raised as-is.
    """
    try:
        # Forward every validated field (model, messages, sampling
        # params, stream flag) to the backend handler as keyword args.
        params = request.dict()

        if request.stream:

            def generate():
                # Wrap each text chunk in the OpenAI SSE delta envelope.
                for chunk in api_handler.generate_completion(**params):
                    payload = json.dumps(
                        {"choices": [{"delta": {"content": chunk}}]}
                    )
                    yield f"data: {payload}\n\n"
                # OpenAI-style end-of-stream sentinel.
                yield "data: [DONE]\n\n"

            # NOTE(review): exceptions raised inside generate() occur
            # after this function returns and are NOT caught below.
            return StreamingResponse(
                generate(),
                media_type="text/event-stream",
            )

        response = api_handler.generate_completion(**params)
        return response

    except HTTPException:
        # Preserve deliberate HTTP errors (status/detail) instead of
        # re-wrapping them as opaque 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
if __name__ == "__main__":
    # Development entry point: serve on all interfaces, port 8000.
    import uvicorn

    uvicorn.run(app=app, host="0.0.0.0", port=8000)