File size: 4,024 Bytes
5d6c1fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from langchain_anthropic import ChatAnthropic
from langchain_openai import ChatOpenAI
from langchain_ollama import ChatOllama
from langchain_core.language_models.base import BaseLanguageModel
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from typing import Optional, Dict, List, Any
import os
import requests
import json
from dotenv import load_dotenv
from dataclasses import dataclass


# Load variables from a .env file (if one is present) into the process
# environment at import time, ahead of the os.getenv() API-key lookups
# performed in LLMProvider._setup_providers.
load_dotenv()


@dataclass
class GeminiResponse:
    """Container for the text returned by a Gemini API call.

    GeminiProvider.chat builds one of these from the text of the first part
    of the first candidate in the API reply. The single ``content`` attribute
    presumably mirrors the ``.content`` attribute of LangChain messages so
    callers can read both the same way -- confirm against callers.
    """

    # Generated reply text.
    content: str


class GeminiProvider:
    """Minimal REST client for Google's Gemini ``generateContent`` endpoint.

    Exposes ``chat``, ``invoke`` and ``generate`` so it can be registered
    next to the LangChain chat models used elsewhere in this module. Each
    call performs one blocking HTTP POST; there is no streaming and no retry.
    """

    # Seconds to wait for the HTTP request before giving up. Without a
    # timeout ``requests`` can block forever on a stalled connection.
    DEFAULT_TIMEOUT = 60.0

    # Dict-message roles that are sent to Gemini as "user" turns. "user" is
    # accepted in addition to "human" so OpenAI-style dicts are not
    # mislabelled as model turns.
    _USER_ROLES = ("human", "user")

    def __init__(self, api_key: str, timeout: float = DEFAULT_TIMEOUT):
        """Store the credentials and endpoint used for every request.

        Args:
            api_key: Google API key used to authenticate requests.
            timeout: Seconds to wait for the HTTP request; passed straight
                to ``requests.post``.
        """
        self.api_key = api_key
        self.base_url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent"
        self.timeout = timeout

    @staticmethod
    def _to_gemini_message(msg: Any) -> Dict[str, Any]:
        """Convert one LangChain message or ``{"role", "content"}`` dict to Gemini's format."""
        if isinstance(msg, BaseMessage):
            # Any non-human LangChain message is sent as a "model" turn.
            role = "user" if isinstance(msg, HumanMessage) else "model"
            content = msg.content
        else:
            role = "user" if msg["role"] in GeminiProvider._USER_ROLES else "model"
            content = msg["content"]
        return {"role": role, "parts": [{"text": content}]}

    @staticmethod
    def _extract_text(result: Any) -> Optional[str]:
        """Return the text of the first part of the first candidate, or None if absent."""
        if not isinstance(result, dict):
            return None
        candidates = result.get("candidates")
        if not isinstance(candidates, list) or not candidates:
            return None
        first = candidates[0]
        # A candidate may come back without "content"/"parts"; treat that as
        # "no response" rather than letting a KeyError escape.
        content = first.get("content") if isinstance(first, dict) else None
        parts = content.get("parts") if isinstance(content, dict) else None
        if not isinstance(parts, list) or not parts:
            return None
        head = parts[0]
        return head.get("text") if isinstance(head, dict) else None

    def chat(self, messages: List[Dict[str, Any]]) -> GeminiResponse:
        """Send a conversation to Gemini and return the generated text.

        Args:
            messages: Conversation turns, each either a LangChain
                ``BaseMessage`` or a dict with ``"role"`` and ``"content"``
                keys. Dict roles ``"human"``/``"user"`` become Gemini
                ``"user"`` turns; every other role becomes ``"model"``.

        Returns:
            A ``GeminiResponse`` holding the text of the first part of the
            first candidate.

        Raises:
            KeyError: If a dict message lacks ``"role"`` or ``"content"``.
            Exception: If the request fails, the server returns an error
                status, the body is not valid JSON, or no text was generated.
                The message is prefixed with ``"Error calling Gemini API: "``.
        """
        gemini_messages = [self._to_gemini_message(msg) for msg in messages]

        # The key travels in a header rather than a ``?key=`` query parameter:
        # requests embeds the full URL in HTTPError messages, so a query
        # parameter would leak the key into the exception raised below.
        headers = {
            "Content-Type": "application/json",
            "x-goog-api-key": self.api_key,
        }

        data = {
            "contents": gemini_messages,
            "generationConfig": {
                "temperature": 0.7,
                "topP": 0.8,
                "topK": 40,
                "maxOutputTokens": 2048,
            },
        }

        try:
            response = requests.post(
                self.base_url,
                headers=headers,
                json=data,
                timeout=self.timeout,
            )
            response.raise_for_status()
            result = response.json()
        except Exception as e:
            # Translation boundary: callers only ever see a plain Exception
            # from this method, but the original cause stays attached.
            raise Exception(f"Error calling Gemini API: {str(e)}") from e

        text = self._extract_text(result)
        if text is None:
            raise Exception("Error calling Gemini API: No response generated")
        return GeminiResponse(content=text)

    def invoke(self, messages: List[BaseMessage], **kwargs) -> GeminiResponse:
        """Alias for ``chat``; extra keyword arguments are accepted and ignored."""
        return self.chat(messages)

    def generate(self, prompts, **kwargs) -> GeminiResponse:
        """Generate a reply for a single prompt.

        Args:
            prompts: A prompt string, or a list whose first element is the
                prompt. Elements after the first are ignored.
            **kwargs: Accepted and ignored.

        Returns:
            The ``GeminiResponse`` for the prompt.

        Raises:
            ValueError: If ``prompts`` is neither a string nor a list, or is
                an empty list.
        """
        if isinstance(prompts, str):
            prompt = prompts
        elif isinstance(prompts, list):
            if not prompts:
                # Previously surfaced as a bare IndexError from prompts[0].
                raise ValueError("Unsupported prompt format: empty prompt list")
            prompt = prompts[0]
        else:
            raise ValueError("Unsupported prompt format")
        return self.invoke([HumanMessage(content=prompt)])

class LLMProvider:
    """Registry of chat-model clients, keyed by a display name.

    A client is registered only when its API key environment variable is set
    to a non-empty value: ``GOOGLE_API_KEY`` -> ``'Gemini'``,
    ``ANTHROPIC_API_KEY`` -> ``'Claude'``, ``OPENAI_API_KEY`` -> ``'ChatGPT'``.
    """

    def __init__(self):
        """Create the empty registry and fill it from the environment."""
        self.providers: Dict[str, Any] = {}
        self._setup_providers()

    def _setup_providers(self):
        """Instantiate and register a client for every configured API key."""
        registry = self.providers

        # Google Gemini
        gemini_key = os.getenv('GOOGLE_API_KEY')
        if gemini_key:
            registry['Gemini'] = GeminiProvider(api_key=gemini_key)

        # Anthropic
        claude_key = os.getenv('ANTHROPIC_API_KEY')
        if claude_key:
            registry['Claude'] = ChatAnthropic(
                api_key=claude_key,
                model_name="claude-3-5-sonnet-20241022",
            )

        # OpenAI
        gpt_key = os.getenv('OPENAI_API_KEY')
        if gpt_key:
            registry['ChatGPT'] = ChatOpenAI(
                api_key=gpt_key,
                model_name="gpt-4o-2024-11-20",
            )

    def get_available_providers(self) -> list[str]:
        """Return the names of all registered providers, in registration order."""
        return [*self.providers]

    def get_provider(self, name: str) -> Optional[Any]:
        """Return the provider registered under ``name``, or None if there is none."""
        return self.providers.get(name, None)