import requests
import json
from typing import Generator, Optional
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
import uvicorn
from dotenv import load_dotenv
import os
load_dotenv()

app = FastAPI()
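
# Expected environment variables (loaded from a .env file by load_dotenv above).
# The actual URLs are deployment-specific and are not part of this file:
#
#     API_ENDPOINT=<Prefind search endpoint URL>
#     DEVICE_TOKEN_URL=<Prefind device-token endpoint URL>
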
class v1:
    """
    A class to interact with the Prefind AI API.
    """

    AVAILABLE_MODELS = ["llama", "claude"]

    def __init__(
        self,
        model: str = "claude",
        timeout: int = 30,
        proxies: Optional[dict] = None,
    ):
        """
        Initializes the Prefind AI API client with the given parameters.

        Args:
            model (str, optional): The AI model to use for text generation. Defaults to "claude".
                Options: "llama", "claude".
            timeout (int, optional): HTTP request timeout in seconds. Defaults to 30.
            proxies (dict, optional): HTTP request proxies. Defaults to None (no proxies).
        """
        if model not in self.AVAILABLE_MODELS:
            raise ValueError(
                f"Model '{model}' is not supported. Choose from {self.AVAILABLE_MODELS}."
            )
        self.session = requests.Session()
        self.api_endpoint = os.getenv("API_ENDPOINT")
        self.timeout = timeout
        self.model = model
        self.device_token = self.get_device_token()
        self.session.headers.update(
            {
                "Content-Type": "application/json",
                "Accept": "text/event-stream",
            }
        )
        # Avoid a mutable default argument; fall back to an empty proxy mapping.
        self.session.proxies = proxies or {}

    def get_device_token(self) -> str:
        """Fetches a fresh session token from the device-token endpoint."""
        device_token_url = os.getenv("DEVICE_TOKEN_URL")
        headers = {"Content-Type": "application/json; charset=utf-8"}
        response = requests.post(
            device_token_url, headers=headers, data=json.dumps({}), timeout=self.timeout
        )
        if response.status_code == 200:
            device_token_data = response.json()
            return device_token_data["sessionToken"]
        raise Exception(
            f"Failed to get device token - ({response.status_code}, {response.reason}) - {response.text}"
        )

    def ask(self, prompt: str) -> Generator[str, None, None]:
        """Sends the prompt and yields the response text chunk by chunk."""
        search_data = {"query": prompt, "deviceToken": self.device_token}
        response = self.session.post(
            self.api_endpoint, json=search_data, stream=True, timeout=self.timeout
        )
        if not response.ok:
            raise Exception(
                f"Failed to generate response - ({response.status_code}, {response.reason}) - {response.text}"
            )
        for line in response.iter_lines(decode_unicode=True):
            # SSE payload lines are prefixed with "data: "; skip blanks and keep-alives.
            if line and line.startswith("data: "):
                data = json.loads(line[6:])
                if data["type"] == "chunk":
                    model = data["model"]
                    # Forward only chunks produced by the model the caller selected.
                    if (self.model == "llama" and model == "OPENROUTER_LLAMA_3") or (
                        self.model == "claude" and model == "OPENROUTER_CLAUDE"
                    ):
                        content = data["chunk"]["content"]
                        if content:
                            yield content

    def chat(self, prompt: str) -> Generator[str, None, None]:
        """Stream responses as string chunks"""
        return self.ask(prompt)
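
# A minimal usage sketch for the client class above, assuming API_ENDPOINT and
# DEVICE_TOKEN_URL are set in the environment (e.g. via a .env file):
#
#     ai = v1(model="claude")
#     for chunk in ai.chat("Explain streaming responses in one sentence."):
#         print(chunk, end="", flush=True)
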
@app.get("/Search/pro")
async def chat(prompt: str, model: str = "claude"):
    if model not in v1.AVAILABLE_MODELS:
        raise HTTPException(
            status_code=400,
            detail=f"Model '{model}' is not supported. Choose from {v1.AVAILABLE_MODELS}.",
        )
    ai = v1(model=model)

    def response_generator():
        # Re-wrap each text chunk as a server-sent event for the HTTP client.
        for chunk in ai.chat(prompt):
            yield f"data: {chunk}\n\n"

    return StreamingResponse(response_generator(), media_type="text/event-stream")
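
# A sketch of consuming the endpoint from a separate client process, assuming
# the server below is running locally on port 7860:
#
#     import requests
#
#     with requests.get(
#         "http://localhost:7860/Search/pro",
#         params={"prompt": "Hello!", "model": "claude"},
#         stream=True,
#     ) as resp:
#         for line in resp.iter_lines(decode_unicode=True):
#             if line and line.startswith("data: "):
#                 print(line[6:], end="", flush=True)
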
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)