import requests
import json
from typing import Generator, Optional
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
import uvicorn
from dotenv import load_dotenv
import os

load_dotenv()

app = FastAPI()

class v1:
    """
    A class to interact with the Prefind AI API.
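
    Example (a minimal usage sketch; assumes API_ENDPOINT and DEVICE_TOKEN_URL
    are set in the environment / .env file):

        ai = v1(model="claude")
        for chunk in ai.chat("Hello"):
            print(chunk, end="", flush=True)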
    """

    AVAILABLE_MODELS = ["llama", "claude"]

    def __init__(
        self,
        model: str = "claude",
        timeout: int = 30,
        proxies: Optional[dict] = None,
    ):
        """
        Initializes the Prefind AI API with given parameters.

        Args:
            model (str, optional): The AI model to use for text generation. Defaults to "claude".
                                    Options: "llama", "claude".
            timeout (int, optional): HTTP request timeout in seconds. Defaults to 30.
            proxies (dict, optional): HTTP request proxies. Defaults to None (no proxies).
        """
        if model not in self.AVAILABLE_MODELS:
            raise ValueError(f"Model '{model}' is not supported. Choose from {self.AVAILABLE_MODELS}.")

        self.session = requests.Session()
        self.api_endpoint = os.getenv("API_ENDPOINT")
        if not self.api_endpoint:
            raise ValueError("API_ENDPOINT environment variable is not set.")
        self.timeout = timeout
        self.model = model
        self.device_token = self.get_device_token()

        self.session.headers.update(
            {
                "Content-Type": "application/json",
                "Accept": "text/event-stream",
            }
        )
        self.session.proxies = proxies or {}

    def get_device_token(self) -> str:
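        """
        Fetch a session/device token from the endpoint configured via the
        DEVICE_TOKEN_URL environment variable; the token is sent with every
        search query.
        """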
        device_token_url = os.getenv("DEVICE_TOKEN_URL")
        if not device_token_url:
            raise ValueError("DEVICE_TOKEN_URL environment variable is not set.")
        headers = {"Content-Type": "application/json; charset=utf-8"}
        data = {}
        response = requests.post(
            device_token_url, headers=headers, data=json.dumps(data), timeout=self.timeout
        )

        if response.status_code == 200:
            device_token_data = response.json()
            return device_token_data["sessionToken"]
        else:
            raise Exception(
                f"Failed to get device token - ({response.status_code}, {response.reason}) - {response.text}"
            )

    def ask(self, prompt: str) -> Generator[str, None, None]:
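        """
        Send a prompt to the search endpoint and yield the model's reply as
        plain-text chunks parsed from the server-sent event stream.
        """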
        search_data = {"query": prompt, "deviceToken": self.device_token}

        response = self.session.post(
            self.api_endpoint, json=search_data, stream=True, timeout=self.timeout
        )
        if not response.ok:
            raise Exception(
                f"Failed to generate response - ({response.status_code}, {response.reason}) - {response.text}"
            )

        # The endpoint streams server-sent events; payload lines are prefixed with "data: ".
        for line in response.iter_lines(decode_unicode=True):
            if line and line.startswith("data: "):
                data = json.loads(line[len("data: "):])
                if data['type'] == 'chunk':
                    model = data['model']
                    # Only yield chunks produced by the model this client was configured for.
                    if (self.model == "llama" and model == 'OPENROUTER_LLAMA_3') or \
                       (self.model == "claude" and model == 'OPENROUTER_CLAUDE'):
                        content = data['chunk']['content']
                        if content:
                            yield content

    def chat(self, prompt: str) -> Generator[str, None, None]:
        """Stream responses as string chunks"""
        return self.ask(prompt)


# Declared as a regular (sync) route so the blocking `requests` calls run in
# FastAPI's threadpool instead of blocking the event loop.
@app.get("/Search/pro")
def chat(prompt: str, model: str = "claude"):
    if model not in v1.AVAILABLE_MODELS:
        raise HTTPException(status_code=400, detail=f"Model '{model}' is not supported. Choose from {v1.AVAILABLE_MODELS}.")

    ai = v1(model=model)
    
    def response_generator():
        for chunk in ai.chat(prompt):
            yield f"data: {chunk}\n\n"
    
    return StreamingResponse(response_generator(), media_type="text/event-stream")
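
# Example request (assuming the server is running locally on the port below):
#   curl -N "http://localhost:7860/Search/pro?prompt=Hello&model=claude"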


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)