sivan22 commited on
Commit
b917882
·
verified ·
1 Parent(s): 2c517f0

Update chat_gemini.py

Browse files
Files changed (1) hide show
  1. chat_gemini.py +262 -263
chat_gemini.py CHANGED
@@ -1,264 +1,263 @@
1
- import json
2
- from random import choices
3
- import string
4
- from langchain.tools import BaseTool
5
- import requests
6
- from dotenv import load_dotenv
7
- from dataclasses import dataclass
8
- from langchain_core.language_models.chat_models import BaseChatModel
9
- from typing import (
10
- Any,
11
- Callable,
12
- Dict,
13
- List,
14
- Literal,
15
- Mapping,
16
- Optional,
17
- Sequence,
18
- Type,
19
- Union,
20
- cast,
21
- )
22
- from langchain_core.callbacks import (
23
- CallbackManagerForLLMRun,
24
- )
25
- from langchain_core.callbacks.manager import AsyncCallbackManagerForLLMRun
26
- from langchain_core.exceptions import OutputParserException
27
- from langchain_core.language_models import LanguageModelInput
28
- from langchain_core.language_models.chat_models import BaseChatModel, LangSmithParams
29
- from langchain_core.messages import (
30
- AIMessage,
31
- BaseMessage,
32
- HumanMessage,
33
- ToolMessage,
34
- SystemMessage,
35
- )
36
- from langchain_core.outputs import ChatGeneration, ChatResult
37
- from langchain_core.runnables import Runnable
38
- from langchain_core.tools import BaseTool
39
-
40
-
41
- class ChatGemini(BaseChatModel):
42
-
43
- @property
44
- def _llm_type(self) -> str:
45
- """Get the type of language model used by this chat model."""
46
- return "gemini"
47
-
48
- api_key :str
49
- base_url:str = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash-exp:generateContent"
50
- model_kwargs: Any = {}
51
-
52
- def _generate(
53
- self,
54
- messages: list[BaseMessage],
55
- stop: Optional[list[str]] = None,
56
- run_manager: Optional[CallbackManagerForLLMRun] = None,
57
- **kwargs: Any,
58
- ) -> ChatResult:
59
- """Generate a chat response using the Gemini API.
60
-
61
- This method handles both regular text responses and function calls.
62
- For function calls, it returns a ToolMessage with structured function call data
63
- that can be processed by Langchain's agent executor.
64
-
65
- Function calls are returned with:
66
- - tool_name: The name of the function to call
67
- - tool_call_id: A unique identifier for the function call (name is used as Gemini doesn't provide one)
68
- - content: The function arguments as a JSON string
69
- - additional_kwargs: Contains the full function call details
70
-
71
- Args:
72
- messages: List of input messages
73
- stop: Optional list of stop sequences
74
- run_manager: Optional callback manager
75
- **kwargs: Additional arguments passed to the Gemini API
76
-
77
- Returns:
78
- ChatResult containing either an AIMessage for text responses
79
- or a ToolMessage for function calls
80
- """
81
- # Convert messages to Gemini format
82
- gemini_messages = []
83
- system_message = None
84
- for msg in messages:
85
- # Handle both dict and LangChain message objects
86
- if isinstance(msg, BaseMessage):
87
- if isinstance(msg, SystemMessage):
88
- system_message = msg.content
89
- kwargs["system_instruction"]= {"parts":[{"text": system_message}]}
90
- continue
91
- if isinstance(msg, HumanMessage):
92
- role = "user"
93
- content = msg.content
94
- elif isinstance(msg, AIMessage):
95
- role = "model"
96
- content = msg.content
97
- elif isinstance(msg, ToolMessage):
98
- # Handle tool messages by adding them as function outputs
99
- gemini_messages.append(
100
- {
101
- "role": "model",
102
- "parts": [{
103
- "functionResponse": {
104
- "name": msg.name,
105
- "response": {"name": msg.name, "content": msg.content},
106
- }}]}
107
- )
108
- continue
109
- else:
110
- role = "user" if msg["role"] == "human" else "model"
111
- content = msg["content"]
112
-
113
- message_part = {
114
- "role": role,
115
- "parts":[{"functionCall": { "name": msg.tool_calls[0]["name"], "args": msg.tool_calls[0]["args"]}}] if isinstance(msg, AIMessage) and msg.tool_calls else [{"text": content}]
116
- }
117
- gemini_messages.append(message_part)
118
-
119
-
120
-
121
- # Prepare the request
122
- headers = {
123
- "Content-Type": "application/json"
124
- }
125
-
126
- params = {
127
- "key": self.api_key
128
- }
129
-
130
- data = {
131
- "contents": gemini_messages,
132
- "generationConfig": {
133
- "temperature": 0.7,
134
- "topP": 0.8,
135
- "topK": 40,
136
- "maxOutputTokens": 2048,
137
- },
138
- **kwargs
139
- }
140
-
141
-
142
- try:
143
- response = requests.post(
144
- self.base_url,
145
- headers=headers,
146
- params=params,
147
- json=data,
148
- verify='C:\\ProgramData\\NetFree\\CA\\netfree-ca-bundle-curl.crt'
149
- )
150
- response.raise_for_status()
151
-
152
- result = response.json()
153
- if "candidates" in result and len(result["candidates"]) > 0 and "parts" in result["candidates"][0]["content"]:
154
- parts = result["candidates"][0]["content"]["parts"]
155
- tool_calls = []
156
- content = ""
157
- for part in parts:
158
- if "text" in part:
159
- content += part["text"]
160
- if "functionCall" in part:
161
- function_call = part["functionCall"]
162
- tool_calls.append( {
163
- "name": function_call["name"],
164
- "id": function_call["name"]+random_string(5), # Gemini doesn't provide a unique id,}
165
- "args": function_call["args"],
166
- "type": "tool_call",})
167
- # Create a proper ToolMessage with structured function call data
168
- return ChatResult(generations=[
169
- ChatGeneration(
170
- message=AIMessage(
171
- content=content,
172
- tool_calls=tool_calls,
173
- ) if len(tool_calls) > 0 else AIMessage(content=content)
174
- )
175
- ])
176
-
177
-
178
- else:
179
- raise Exception("No response generated")
180
-
181
- except Exception as e:
182
- raise Exception(f"Error calling Gemini API: {str(e)}")
183
-
184
-
185
- def bind_tools(
186
- self,
187
- tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
188
- *,
189
- tool_choice: Optional[Union[dict, str, Literal["auto", "any"], bool]] = None,
190
- **kwargs: Any,
191
- ) -> Runnable[LanguageModelInput, BaseMessage]:
192
- """Bind tool-like objects to this chat model.
193
-
194
-
195
- Args:
196
- tools: A list of tool definitions to bind to this chat model.
197
- Supports any tool definition handled by
198
- :meth:`langchain_core.utils.function_calling.convert_to_openai_tool`.
199
- tool_choice: If provided, which tool for model to call. **This parameter
200
- is currently ignored as it is not supported by Ollama.**
201
- kwargs: Any additional parameters are passed directly to
202
- ``self.bind(**kwargs)``.
203
- """
204
-
205
- formatted_tools = {"function_declarations": [convert_to_gemini_tool(tool) for tool in tools]}
206
- return super().bind(tools=formatted_tools, **kwargs)
207
-
208
- def convert_to_gemini_tool(
209
- tool: Union[BaseTool],
210
- *,
211
- strict: Optional[bool] = None,
212
- ) -> dict[str, Any]:
213
- """Convert a tool-like object to an Gemini tool schema.
214
-
215
- Gemini tool schema reference:
216
- https://ai.google.dev/gemini-api/docs/function-calling#function_calling_mode
217
-
218
- Args:
219
- tool:
220
- BaseTool.
221
- strict:
222
- If True, model output is guaranteed to exactly match the JSON Schema
223
- provided in the function definition. If None, ``strict`` argument will not
224
- be included in tool definition.
225
-
226
- Returns:
227
- A dict version of the passed in tool which is compatible with the
228
- Gemini tool-calling API.
229
- """
230
- if isinstance(tool, BaseTool):
231
- # Extract the tool's schema
232
- schema = tool.args_schema.schema() if tool.args_schema else {"type": "object", "properties": {}}
233
-
234
- #convert to gemini schema
235
- raw_properties = schema.get("properties", {})
236
- properties = {}
237
- for key, value in raw_properties.items():
238
- properties[key] = {
239
- "type": value.get("type", "string"),
240
- "description": value.get("title", ""),
241
- }
242
-
243
-
244
- # Build the function definition
245
- function_def = {
246
- "name": tool.name,
247
- "description": tool.description,
248
- "parameters": {
249
- "type": "object",
250
- "properties": properties,
251
- "required": schema.get("required", [])
252
- }
253
- }
254
-
255
- if strict is not None:
256
- function_def["strict"] = strict
257
-
258
- return function_def
259
- else:
260
- raise ValueError(f"Unsupported tool type: {type(tool)}")
261
-
262
- def random_string(length: int) -> str:
263
- return ''.join(choices(string.ascii_letters + string.digits, k=length))
264
 
 
1
+ import json
2
+ from random import choices
3
+ import string
4
+ from langchain.tools import BaseTool
5
+ import requests
6
+ from dotenv import load_dotenv
7
+ from dataclasses import dataclass
8
+ from langchain_core.language_models.chat_models import BaseChatModel
9
+ from typing import (
10
+ Any,
11
+ Callable,
12
+ Dict,
13
+ List,
14
+ Literal,
15
+ Mapping,
16
+ Optional,
17
+ Sequence,
18
+ Type,
19
+ Union,
20
+ cast,
21
+ )
22
+ from langchain_core.callbacks import (
23
+ CallbackManagerForLLMRun,
24
+ )
25
+ from langchain_core.callbacks.manager import AsyncCallbackManagerForLLMRun
26
+ from langchain_core.exceptions import OutputParserException
27
+ from langchain_core.language_models import LanguageModelInput
28
+ from langchain_core.language_models.chat_models import BaseChatModel, LangSmithParams
29
+ from langchain_core.messages import (
30
+ AIMessage,
31
+ BaseMessage,
32
+ HumanMessage,
33
+ ToolMessage,
34
+ SystemMessage,
35
+ )
36
+ from langchain_core.outputs import ChatGeneration, ChatResult
37
+ from langchain_core.runnables import Runnable
38
+ from langchain_core.tools import BaseTool
39
+
40
+
41
class ChatGemini(BaseChatModel):
    """LangChain chat model that calls the Google Gemini REST API directly.

    Conversations are posted with ``requests`` to the ``generateContent``
    endpoint, and the returned parts (text and function calls) are mapped
    back onto LangChain ``AIMessage`` objects.
    """

    # API key sent as the ``key`` query parameter on every request.
    api_key: str
    # Fully-qualified generateContent endpoint for the target model.
    base_url: str = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash-exp:generateContent"
    # Extra fields merged into every request payload (pydantic copies this
    # default per instance, so the shared-dict pitfall does not apply).
    model_kwargs: Any = {}

    @property
    def _llm_type(self) -> str:
        """Get the type of language model used by this chat model."""
        return "gemini"

    @staticmethod
    def _convert_messages(
        messages: list[BaseMessage], kwargs: dict
    ) -> list[dict]:
        """Convert LangChain messages into Gemini ``contents`` entries.

        A ``SystemMessage`` is not part of ``contents``; it is injected
        into *kwargs* (mutated in place) as ``system_instruction``.
        """
        gemini_messages: list[dict] = []
        for msg in messages:
            if isinstance(msg, BaseMessage):
                if isinstance(msg, SystemMessage):
                    kwargs["system_instruction"] = {
                        "parts": [{"text": msg.content}]
                    }
                    continue
                if isinstance(msg, ToolMessage):
                    # Tool output is replayed as a functionResponse part.
                    # NOTE(review): Gemini docs show role "user" for
                    # functionResponse; "model" is preserved from the
                    # original implementation — confirm against the API.
                    gemini_messages.append(
                        {
                            "role": "model",
                            "parts": [
                                {
                                    "functionResponse": {
                                        "name": msg.name,
                                        "response": {
                                            "name": msg.name,
                                            "content": msg.content,
                                        },
                                    }
                                }
                            ],
                        }
                    )
                    continue
                # Human and any other non-AI message types map to "user";
                # previously an unhandled subtype left role/content unset.
                role = "model" if isinstance(msg, AIMessage) else "user"
                content = msg.content
            else:
                # Plain dict messages: {"role": ..., "content": ...}.
                role = "user" if msg["role"] == "human" else "model"
                content = msg["content"]

            if isinstance(msg, AIMessage) and msg.tool_calls:
                # Replay a prior assistant tool call as a functionCall part.
                call = msg.tool_calls[0]
                parts = [
                    {
                        "functionCall": {
                            "name": call["name"],
                            "args": call["args"],
                        }
                    }
                ]
            else:
                parts = [{"text": content}]
            gemini_messages.append({"role": role, "parts": parts})
        return gemini_messages

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: Optional[list[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a chat response using the Gemini API.

        Handles both plain text responses and function calls; function
        calls are surfaced as ``tool_calls`` on the returned ``AIMessage``
        (Gemini supplies no call id, so one is synthesized from the
        function name plus a random suffix).

        Args:
            messages: Conversation history to send.
            stop: Accepted for interface compatibility; not forwarded.
            run_manager: Optional callback manager (unused).
            **kwargs: Extra payload fields merged into the request body
                (e.g. ``tools`` injected by :meth:`bind_tools`).

        Returns:
            ChatResult containing a single AIMessage generation.

        Raises:
            Exception: If the HTTP call fails or no candidate with parts
                is returned.
        """
        gemini_messages = self._convert_messages(messages, kwargs)

        data = {
            "contents": gemini_messages,
            "generationConfig": {
                "temperature": 0.7,
                "topP": 0.8,
                "topK": 40,
                "maxOutputTokens": 2048,
            },
            **kwargs,
        }

        try:
            response = requests.post(
                self.base_url,
                headers={"Content-Type": "application/json"},
                params={"key": self.api_key},
                json=data,
                timeout=60,  # fail fast instead of hanging indefinitely
            )
            response.raise_for_status()
            result = response.json()
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise Exception(f"Error calling Gemini API: {str(e)}") from e

        candidates = result.get("candidates") or []
        # Checked outside the try block so this error is not re-wrapped
        # as "Error calling Gemini API" (the original swallowed it).
        if not candidates or "parts" not in candidates[0].get("content", {}):
            raise Exception("No response generated")

        content = ""
        tool_calls = []
        for part in candidates[0]["content"]["parts"]:
            if "text" in part:
                content += part["text"]
            if "functionCall" in part:
                function_call = part["functionCall"]
                tool_calls.append(
                    {
                        "name": function_call["name"],
                        # Gemini doesn't provide a unique id; synthesize one.
                        "id": function_call["name"] + random_string(5),
                        "args": function_call["args"],
                        "type": "tool_call",
                    }
                )

        message = (
            AIMessage(content=content, tool_calls=tool_calls)
            if tool_calls
            else AIMessage(content=content)
        )
        return ChatResult(generations=[ChatGeneration(message=message)])

    def bind_tools(
        self,
        tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
        *,
        tool_choice: Optional[Union[dict, str, Literal["auto", "any"], bool]] = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind tool-like objects to this chat model.

        Args:
            tools: A list of tool definitions to bind to this chat model;
                each is converted with :func:`convert_to_gemini_tool`.
            tool_choice: Currently ignored — it is not forwarded to the
                Gemini request.
            kwargs: Any additional parameters are passed directly to
                ``self.bind(**kwargs)``.
        """
        formatted_tools = {
            "function_declarations": [
                convert_to_gemini_tool(tool) for tool in tools
            ]
        }
        return super().bind(tools=formatted_tools, **kwargs)
206
+
207
def convert_to_gemini_tool(
    tool: BaseTool,
    *,
    strict: Optional[bool] = None,
) -> dict[str, Any]:
    """Convert a LangChain ``BaseTool`` to a Gemini function declaration.

    Gemini tool schema reference:
    https://ai.google.dev/gemini-api/docs/function-calling#function_calling_mode

    Args:
        tool:
            BaseTool to convert.
        strict:
            If True, model output is guaranteed to exactly match the JSON
            Schema provided in the function definition. If None, ``strict``
            is omitted from the declaration.

    Returns:
        A dict version of the passed in tool which is compatible with the
        Gemini tool-calling API.

    Raises:
        ValueError: If *tool* is not a ``BaseTool`` instance.
    """
    if not isinstance(tool, BaseTool):
        raise ValueError(f"Unsupported tool type: {type(tool)}")

    # Extract the tool's JSON schema (empty object schema when absent).
    schema = (
        tool.args_schema.schema()
        if tool.args_schema
        else {"type": "object", "properties": {}}
    )

    # Reduce each property to the subset of JSON Schema Gemini accepts.
    # Prefer the real "description"; the pydantic "title" (the field
    # name) is only a fallback — the original always used the title.
    properties = {
        key: {
            "type": value.get("type", "string"),
            "description": value.get("description", value.get("title", "")),
        }
        for key, value in schema.get("properties", {}).items()
    }

    # Build the function declaration in Gemini's expected shape.
    function_def: dict[str, Any] = {
        "name": tool.name,
        "description": tool.description,
        "parameters": {
            "type": "object",
            "properties": properties,
            "required": schema.get("required", []),
        },
    }

    if strict is not None:
        function_def["strict"] = strict

    return function_def
260
+
261
def random_string(length: int) -> str:
    """Return a pseudo-random alphanumeric string of the given length."""
    alphabet = string.ascii_letters + string.digits
    return "".join(choices(alphabet, k=length))
 
263