sivan22 commited on
Commit
5d6c1fa
·
verified ·
1 Parent(s): 165cf08

Update llm_providers.py

Browse files
Files changed (1) hide show
  1. llm_providers.py +128 -133
llm_providers.py CHANGED
@@ -1,133 +1,128 @@
1
- from langchain_anthropic import ChatAnthropic
2
- from langchain_openai import ChatOpenAI
3
- from langchain_ollama import ChatOllama
4
- from langchain_core.language_models.base import BaseLanguageModel
5
- from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
6
- from typing import Optional, Dict, List, Any
7
- import os
8
- import requests
9
- import json
10
- from dotenv import load_dotenv
11
- from dataclasses import dataclass
12
-
13
-
14
- load_dotenv()
15
-
16
-
17
- @dataclass
18
- class GeminiResponse:
19
- content: str
20
-
21
-
22
- class GeminiProvider:
23
- def __init__(self, api_key: str):
24
- self.api_key = api_key
25
- self.base_url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent"
26
-
27
- def chat(self, messages: List[Dict[str, Any]]) -> GeminiResponse:
28
- # Convert messages to Gemini format
29
- gemini_messages = []
30
- for msg in messages:
31
- # Handle both dict and LangChain message objects
32
- if isinstance(msg, BaseMessage):
33
- role = "user" if isinstance(msg, HumanMessage) else "model"
34
- content = msg.content
35
- else:
36
- role = "user" if msg["role"] == "human" else "model"
37
- content = msg["content"]
38
-
39
- gemini_messages.append({
40
- "role": role,
41
- "parts": [{"text": content}]
42
- })
43
-
44
- # Prepare the request
45
- headers = {
46
- "Content-Type": "application/json"
47
- }
48
-
49
- params = {
50
- "key": self.api_key
51
- }
52
-
53
- data = {
54
- "contents": gemini_messages,
55
- "generationConfig": {
56
- "temperature": 0.7,
57
- "topP": 0.8,
58
- "topK": 40,
59
- "maxOutputTokens": 2048,
60
- }
61
- }
62
-
63
- try:
64
- response = requests.post(
65
- self.base_url,
66
- headers=headers,
67
- params=params,
68
- json=data,
69
- verify='C:\\ProgramData\\NetFree\\CA\\netfree-ca-bundle-curl.crt'
70
- )
71
- response.raise_for_status()
72
-
73
- result = response.json()
74
- if "candidates" in result and len(result["candidates"]) > 0:
75
- return GeminiResponse(content=result["candidates"][0]["content"]["parts"][0]["text"])
76
- else:
77
- raise Exception("No response generated")
78
-
79
- except Exception as e:
80
- raise Exception(f"Error calling Gemini API: {str(e)}")
81
-
82
- def invoke(self, messages: List[BaseMessage], **kwargs) -> GeminiResponse:
83
- return self.chat(messages)
84
-
85
- def generate(self, prompts, **kwargs) -> GeminiResponse:
86
- if isinstance(prompts, str):
87
- return self.invoke([HumanMessage(content=prompts)])
88
- elif isinstance(prompts, list):
89
- return self.invoke([HumanMessage(content=prompts[0])])
90
- raise ValueError("Unsupported prompt format")
91
-
92
- class LLMProvider:
93
- def __init__(self):
94
- self.providers: Dict[str, Any] = {}
95
- self._setup_providers()
96
-
97
- def _setup_providers(self):
98
- os.environ['REQUESTS_CA_BUNDLE'] = 'C:\\ProgramData\\NetFree\\CA\\netfree-ca-bundle-curl.crt'
99
-
100
- # Google Gemini
101
- if google_key := os.getenv('GOOGLE_API_KEY'):
102
- self.providers['Gemini'] = GeminiProvider(api_key=google_key)
103
-
104
-
105
- # Anthropic
106
- if anthropic_key := os.getenv('ANTHROPIC_API_KEY'):
107
- self.providers['Claude'] = ChatAnthropic(
108
- api_key=anthropic_key,
109
- model_name="claude-3-5-sonnet-20241022",
110
-
111
- )
112
-
113
- # OpenAI
114
- if openai_key := os.getenv('OPENAI_API_KEY'):
115
- self.providers['ChatGPT'] = ChatOpenAI(
116
- api_key=openai_key,
117
- model_name="gpt-4o-2024-11-20"
118
- )
119
-
120
-
121
- # Ollama (local)
122
- try:
123
- self.providers['Ollama-dictalm2.0'] = ChatOllama(model="dictaLM")
124
- except Exception:
125
- pass # Ollama not available
126
-
127
- def get_available_providers(self) -> list[str]:
128
- """Return list of available provider names"""
129
- return list(self.providers.keys())
130
-
131
- def get_provider(self, name: str) -> Optional[Any]:
132
- """Get LLM provider by name"""
133
- return self.providers.get(name)
 
1
+ from langchain_anthropic import ChatAnthropic
2
+ from langchain_openai import ChatOpenAI
3
+ from langchain_ollama import ChatOllama
4
+ from langchain_core.language_models.base import BaseLanguageModel
5
+ from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
6
+ from typing import Optional, Dict, List, Any
7
+ import os
8
+ import requests
9
+ import json
10
+ from dotenv import load_dotenv
11
+ from dataclasses import dataclass
12
+
13
+
14
+ load_dotenv()
15
+
16
+
17
+ @dataclass
18
+ class GeminiResponse:
19
+ content: str
20
+
21
+
22
+ class GeminiProvider:
23
+ def __init__(self, api_key: str):
24
+ self.api_key = api_key
25
+ self.base_url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent"
26
+
27
+ def chat(self, messages: List[Dict[str, Any]]) -> GeminiResponse:
28
+ # Convert messages to Gemini format
29
+ gemini_messages = []
30
+ for msg in messages:
31
+ # Handle both dict and LangChain message objects
32
+ if isinstance(msg, BaseMessage):
33
+ role = "user" if isinstance(msg, HumanMessage) else "model"
34
+ content = msg.content
35
+ else:
36
+ role = "user" if msg["role"] == "human" else "model"
37
+ content = msg["content"]
38
+
39
+ gemini_messages.append({
40
+ "role": role,
41
+ "parts": [{"text": content}]
42
+ })
43
+
44
+ # Prepare the request
45
+ headers = {
46
+ "Content-Type": "application/json"
47
+ }
48
+
49
+ params = {
50
+ "key": self.api_key
51
+ }
52
+
53
+ data = {
54
+ "contents": gemini_messages,
55
+ "generationConfig": {
56
+ "temperature": 0.7,
57
+ "topP": 0.8,
58
+ "topK": 40,
59
+ "maxOutputTokens": 2048,
60
+ }
61
+ }
62
+
63
+ try:
64
+ response = requests.post(
65
+ self.base_url,
66
+ headers=headers,
67
+ params=params,
68
+ json=data,
69
+
70
+ )
71
+ response.raise_for_status()
72
+
73
+ result = response.json()
74
+ if "candidates" in result and len(result["candidates"]) > 0:
75
+ return GeminiResponse(content=result["candidates"][0]["content"]["parts"][0]["text"])
76
+ else:
77
+ raise Exception("No response generated")
78
+
79
+ except Exception as e:
80
+ raise Exception(f"Error calling Gemini API: {str(e)}")
81
+
82
+ def invoke(self, messages: List[BaseMessage], **kwargs) -> GeminiResponse:
83
+ return self.chat(messages)
84
+
85
+ def generate(self, prompts, **kwargs) -> GeminiResponse:
86
+ if isinstance(prompts, str):
87
+ return self.invoke([HumanMessage(content=prompts)])
88
+ elif isinstance(prompts, list):
89
+ return self.invoke([HumanMessage(content=prompts[0])])
90
+ raise ValueError("Unsupported prompt format")
91
+
92
+ class LLMProvider:
93
+ def __init__(self):
94
+ self.providers: Dict[str, Any] = {}
95
+ self._setup_providers()
96
+
97
+ def _setup_providers(self):
98
+
99
+ # Google Gemini
100
+ if google_key := os.getenv('GOOGLE_API_KEY'):
101
+ self.providers['Gemini'] = GeminiProvider(api_key=google_key)
102
+
103
+
104
+ # Anthropic
105
+ if anthropic_key := os.getenv('ANTHROPIC_API_KEY'):
106
+ self.providers['Claude'] = ChatAnthropic(
107
+ api_key=anthropic_key,
108
+ model_name="claude-3-5-sonnet-20241022",
109
+
110
+ )
111
+
112
+ # OpenAI
113
+ if openai_key := os.getenv('OPENAI_API_KEY'):
114
+ self.providers['ChatGPT'] = ChatOpenAI(
115
+ api_key=openai_key,
116
+ model_name="gpt-4o-2024-11-20"
117
+ )
118
+
119
+
120
+
121
+
122
+ def get_available_providers(self) -> list[str]:
123
+ """Return list of available provider names"""
124
+ return list(self.providers.keys())
125
+
126
+ def get_provider(self, name: str) -> Optional[Any]:
127
+ """Get LLM provider by name"""
128
+ return self.providers.get(name)