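"""A custom LangChain chat model for the Google Gemini REST API.

Implements ``ChatGemini``, a minimal ``BaseChatModel`` subclass that supports
plain text generation and Gemini function calling, plus a helper that converts
LangChain tools into Gemini function declarations.
"""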
import string
from random import choices
from typing import (
    Any,
    Callable,
    Dict,
    Literal,
    Optional,
    Sequence,
    Type,
    Union,
)

import requests
from dotenv import load_dotenv
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
class ChatGemini(BaseChatModel):
    """Minimal LangChain chat model wrapper around the Gemini REST API."""

    api_key: str
    base_url: str = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash-exp:generateContent"
    model_kwargs: Any = {}

    @property
    def _llm_type(self) -> str:
        """Get the type of language model used by this chat model."""
        return "gemini"
    def _generate(
        self,
        messages: list[BaseMessage],
        stop: Optional[list[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a chat response using the Gemini API.

        This method handles both regular text responses and function calls.
        For function calls, it returns an AIMessage whose ``tool_calls`` carry
        structured function-call data that LangChain's agent executor can
        process.

        Each tool call contains:
            - name: The name of the function to call
            - id: A unique identifier for the call (Gemini does not provide
              one, so the function name plus a random suffix is used)
            - args: The function arguments as parsed by Gemini

        Args:
            messages: List of input messages.
            stop: Optional list of stop sequences.
            run_manager: Optional callback manager.
            **kwargs: Additional arguments passed to the Gemini API.

        Returns:
            ChatResult containing an AIMessage with either plain text content
            or populated ``tool_calls``.
        """
        # Convert LangChain messages to the Gemini `contents` format.
        gemini_messages = []
        system_message = None
        for msg in messages:
            # Handle both dict and LangChain message objects
            if isinstance(msg, BaseMessage):
                if isinstance(msg, SystemMessage):
                    # System prompts are sent separately via `system_instruction`.
                    system_message = msg.content
                    kwargs["system_instruction"] = {"parts": [{"text": system_message}]}
                    continue
                if isinstance(msg, HumanMessage):
                    role = "user"
                    content = msg.content
                elif isinstance(msg, AIMessage):
                    role = "model"
                    content = msg.content
                elif isinstance(msg, ToolMessage):
                    # Tool results go back to Gemini as a functionResponse
                    # part; the API expects these in a user-role turn.
                    gemini_messages.append(
                        {
                            "role": "user",
                            "parts": [
                                {
                                    "functionResponse": {
                                        "name": msg.name,
                                        "response": {"name": msg.name, "content": msg.content},
                                    }
                                }
                            ],
                        }
                    )
                    continue
            else:
                role = "user" if msg["role"] == "human" else "model"
                content = msg["content"]

            # Prior assistant tool calls are replayed as functionCall parts
            # so the model sees the full tool-use exchange.
            if isinstance(msg, AIMessage) and msg.tool_calls:
                parts = [
                    {
                        "functionCall": {
                            "name": msg.tool_calls[0]["name"],
                            "args": msg.tool_calls[0]["args"],
                        }
                    }
                ]
            else:
                parts = [{"text": content}]
            gemini_messages.append({"role": role, "parts": parts})
        # Prepare the request
        headers = {"Content-Type": "application/json"}
        params = {"key": self.api_key}
        data = {
            "contents": gemini_messages,
            "generationConfig": {
                "temperature": 0.7,
                "topP": 0.8,
                "topK": 40,
                "maxOutputTokens": 2048,
            },
            **kwargs,
        }
        try:
            response = requests.post(
                self.base_url,
                headers=headers,
                params=params,
                json=data,
            )
            response.raise_for_status()
            result = response.json()

            if (
                "candidates" in result
                and len(result["candidates"]) > 0
                and "parts" in result["candidates"][0]["content"]
            ):
                parts = result["candidates"][0]["content"]["parts"]
                tool_calls = []
                content = ""
                for part in parts:
                    if "text" in part:
                        content += part["text"]
                    if "functionCall" in part:
                        function_call = part["functionCall"]
                        tool_calls.append(
                            {
                                "name": function_call["name"],
                                # Gemini doesn't provide a unique id, so derive one.
                                "id": function_call["name"] + random_string(5),
                                "args": function_call["args"],
                                "type": "tool_call",
                            }
                        )
                # Wrap the result in an AIMessage; tool calls, if any, ride
                # on the message's `tool_calls` field.
                message = (
                    AIMessage(content=content, tool_calls=tool_calls)
                    if tool_calls
                    else AIMessage(content=content)
                )
                return ChatResult(generations=[ChatGeneration(message=message)])
            else:
                raise Exception("No response generated")
        except Exception as e:
            raise Exception(f"Error calling Gemini API: {e}") from e
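    # For reference, the parsing in `_generate` expects a generateContent
    # response shaped roughly like this (abridged):
    # {
    #   "candidates": [
    #     {"content": {"parts": [
    #         {"text": "..."},
    #         {"functionCall": {"name": "...", "args": {...}}}
    #     ]}}
    #   ]
    # }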
    def bind_tools(
        self,
        tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
        *,
        tool_choice: Optional[Union[dict, str, Literal["auto", "any"], bool]] = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind tool-like objects to this chat model.

        Args:
            tools: A list of tool definitions to bind to this chat model.
                Each tool is converted with :func:`convert_to_gemini_tool`,
                which currently supports ``BaseTool`` instances only.
            tool_choice: If provided, which tool the model should call.
                **This parameter is currently ignored.**
            kwargs: Any additional parameters are passed directly to
                ``self.bind(**kwargs)``.
        """
        # The bound kwargs are forwarded to `_generate` by BaseChatModel and
        # spread into the request body as a top-level "tools" entry.
        formatted_tools = {"function_declarations": [convert_to_gemini_tool(tool) for tool in tools]}
        return super().bind(tools=formatted_tools, **kwargs)
def convert_to_gemini_tool(
    tool: BaseTool,
    *,
    strict: Optional[bool] = None,
) -> dict[str, Any]:
    """Convert a tool-like object to a Gemini tool schema.

    Gemini tool schema reference:
    https://ai.google.dev/gemini-api/docs/function-calling#function_calling_mode

    Args:
        tool:
            BaseTool.
        strict:
            If True, model output is guaranteed to exactly match the JSON
            Schema provided in the function definition. If None, the
            ``strict`` argument will not be included in the tool definition.

    Returns:
        A dict version of the passed-in tool which is compatible with the
        Gemini tool-calling API.
    """
    if isinstance(tool, BaseTool):
        # Extract the tool's schema
        schema = tool.args_schema.schema() if tool.args_schema else {"type": "object", "properties": {}}
        # Convert to the Gemini parameter schema, keeping only the fields
        # Gemini understands (type and description per property).
        raw_properties = schema.get("properties", {})
        properties = {}
        for key, value in raw_properties.items():
            properties[key] = {
                "type": value.get("type", "string"),
                "description": value.get("description", value.get("title", "")),
            }
        # Build the function definition
        function_def = {
            "name": tool.name,
            "description": tool.description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": schema.get("required", []),
            },
        }
        if strict is not None:
            function_def["strict"] = strict
        return function_def
    else:
        raise ValueError(f"Unsupported tool type: {type(tool)}")
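# Example (hypothetical tool, not part of this module): a `multiply(a, b)`
# tool converts to roughly the following Gemini function declaration:
# {
#     "name": "multiply",
#     "description": "Multiply two integers.",
#     "parameters": {
#         "type": "object",
#         "properties": {
#             "a": {"type": "integer", "description": "A"},
#             "b": {"type": "integer", "description": "B"},
#         },
#         "required": ["a", "b"],
#     },
# }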
def random_string(length: int) -> str:
    """Return a random alphanumeric string, used to build tool-call ids."""
    return "".join(choices(string.ascii_letters + string.digits, k=length))
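if __name__ == "__main__":
    # Minimal usage sketch. Assumes a GEMINI_API_KEY entry in a local .env
    # file; the `multiply` tool is a hypothetical example, not part of the
    # wrapper above.
    import os

    from langchain_core.tools import tool

    load_dotenv()

    @tool
    def multiply(a: int, b: int) -> int:
        """Multiply two integers."""
        return a * b

    llm = ChatGemini(api_key=os.environ["GEMINI_API_KEY"])
    llm_with_tools = llm.bind_tools([multiply])

    response = llm_with_tools.invoke("What is 6 times 7?")
    # A tool-calling reply carries the arguments on `tool_calls`; a plain
    # text reply carries them on `content`.
    print(response.tool_calls or response.content)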