File size: 2,663 Bytes
4e2263c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import json
import requests
from typing import Dict, Any, Generator, Optional

class DeepInfraHandler:
    """Client for DeepInfra's OpenAI-compatible chat completions endpoint.

    Builds a chat-completion request from keyword arguments and posts it with
    ``requests``. It returns either the decoded JSON body or, when
    ``stream=True`` is requested, a generator of text deltas parsed from the
    server-sent-events (SSE) stream.
    """

    API_URL = "https://api.deepinfra.com/v1/openai/chat/completions"

    def __init__(self):
        # Headers sent with every request. "Accept-Encoding" is deliberately
        # left to requests' default, which only advertises codings urllib3 can
        # actually decode in this environment. Hard-coding "br, zstd" (which
        # need optional packages) could get back a body that cannot be decoded.
        self.headers = {
            "Accept": "text/event-stream",
            "Content-Type": "application/json",
            "Connection": "keep-alive",
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/125.0.0.0 Safari/537.36",
        }

    def _prepare_payload(self, **kwargs) -> Dict[str, Any]:
        """Build the JSON request body from keyword arguments.

        ``model`` and ``messages`` are passed through as given (``None`` if
        absent). The sampling options fall back to the defaults below. Any
        other keyword argument is ignored.
        """
        return {
            "model": kwargs.get("model"),
            "messages": kwargs.get("messages"),
            "temperature": kwargs.get("temperature", 0.7),
            "max_tokens": kwargs.get("max_tokens", 4096),
            "top_p": kwargs.get("top_p", 1.0),
            "frequency_penalty": kwargs.get("frequency_penalty", 0.0),
            "presence_penalty": kwargs.get("presence_penalty", 0.0),
            # The literal creates a fresh list per call, so payloads never
            # share a mutable default.
            "stop": kwargs.get("stop", []),
            "stream": kwargs.get("stream", False),
        }

    def generate_completion(self, *, timeout: Optional[float] = None, **kwargs) -> Any:
        """Send a chat-completion request and return the result.

        Args:
            timeout: Optional ``requests`` timeout in seconds. ``None`` (the
                default) waits indefinitely, matching the previous behaviour.
            **kwargs: Request options forwarded to ``_prepare_payload``
                (``model``, ``messages``, ``temperature``, ``stream``, ...).

        Returns:
            The decoded JSON body when ``stream`` is false. Otherwise, a
            generator yielding the text content of each streamed delta.
        """
        payload = self._prepare_payload(**kwargs)

        response = requests.post(
            self.API_URL,
            headers=self.headers,
            json=payload,
            stream=payload["stream"],
            timeout=timeout,
        )

        # NOTE(review): the HTTP status is not checked. A non-2xx reply is
        # returned as its JSON error body (non-streaming). In streaming mode
        # it presumably yields nothing, since an error body is unlikely to
        # contain "data:" lines. Consider surfacing the error.
        if payload["stream"]:
            return self._handle_streaming_response(response)
        return self._handle_regular_response(response)

    def _handle_streaming_response(self, response) -> Generator[str, None, None]:
        """Yield text deltas from an SSE chat-completion response.

        Each ``data:`` line is parsed as JSON, and ``choices[0].delta.content``
        is yielded when it is non-empty. The ``data: [DONE]`` sentinel ends the
        stream. The following are skipped:

        - non-``data:`` lines (blank separators, comments);
        - payloads that are not valid JSON;
        - chunks without delta content.

        The response is closed when the generator is exhausted or closed early
        by the consumer.
        """
        try:
            for raw_line in response.iter_lines():
                # SSE is UTF-8 by definition. Decode explicitly instead of
                # relying on requests' guess, which is ISO-8859-1 for text/*
                # without a charset and would garble non-ASCII text.
                if isinstance(raw_line, bytes):
                    line = raw_line.decode("utf-8", errors="replace")
                else:
                    line = raw_line
                if not line.startswith("data:"):
                    continue
                data = line[5:].strip()
                # The terminator is a bare token, not JSON, so it must be
                # recognised before attempting to parse.
                if data == "[DONE]":
                    break
                try:
                    chunk = json.loads(data)
                except json.JSONDecodeError:
                    continue
                try:
                    delta_content = chunk["choices"][0]["delta"]["content"]
                except (KeyError, IndexError, TypeError):
                    continue
                # The yield stays outside any except clause so that
                # GeneratorExit (early close) propagates normally.
                if delta_content:
                    yield delta_content
        finally:
            response.close()

    def _handle_regular_response(self, response) -> Dict[str, Any]:
        """Decode a non-streaming response body as JSON.

        Raises:
            Exception: If the body cannot be decoded. The original error is
                chained as ``__cause__``.
        """
        try:
            return response.json()
        except Exception as e:
            raise Exception(f"Error processing response: {str(e)}") from e