# agora-demo / toolformers/langchain_agent.py
# Samuele Marro
# Added cost tracking.
# c07f594
# raw
# history blame
# 3.59 kB
import json
# Import relevant functionality
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent
from langchain_anthropic import ChatAnthropic
import sys
sys.path.append('.')
from toolformers.base import Tool as AgoraTool
from langchain_core.tools import tool as function_to_tool
from toolformers.base import StringParameter, Toolformer, Conversation
from utils import register_cost
# Price of a single token for each supported model, split into prompt
# ("input") and completion ("output") tokens. The inner keys deliberately
# match the token-count field names read from ``usage_metadata`` when a
# conversation turn is priced, so the two can be multiplied key by key.
# NOTE(review): the currency is presumably USD -- confirm against the
# provider's price list.
COSTS = {
    'claude-3-5-sonnet-latest': {'input_tokens': 3e-6, 'output_tokens': 15e-6},
    'claude-3-5-haiku-latest': {'input_tokens': 1e-6, 'output_tokens': 5e-6},
}
class LangChainConversation(Conversation):
    """Multi-turn chat session driven by a LangGraph agent.

    The conversation owns the running message history, replays it to the
    agent on every turn, and reports the token cost of each turn through
    ``register_cost``.
    """

    def __init__(self, model_name, agent, messages, category=None):
        """Store the agent and the initial history.

        Args:
            model_name: Key into ``COSTS`` used to price the tokens of a turn.
            agent: Object exposing ``stream(inputs, stream_mode=...)``.
            messages: Initial message list; ``chat`` extends it in place.
            category: Label forwarded to ``register_cost`` with every cost.
        """
        self.model_name = model_name
        self.agent = agent
        self.messages = messages
        self.category = category

    def chat(self, message, role='user', print_output=True) -> str:
        """Send ``message`` to the agent and return the text of its reply.

        Args:
            message: Text of the new user message.
            role: Accepted for interface compatibility; the message is always
                recorded as a ``HumanMessage``.
            print_output: When true (the default) every streamed state is
                dumped to stdout, followed by a ``----`` separator.

        Returns:
            The concatenated string content of the ``AIMessage`` objects the
            agent produced during this turn.

        Raises:
            KeyError: If the agent replied and ``model_name`` has no entry in
                ``COSTS``.
        """
        self.messages.append(HumanMessage(content=message))
        # Everything at or after this index in the final state was produced
        # by the agent during this turn.
        history_size = len(self.messages)

        # With stream_mode="values" each item is the complete state after a
        # step, so the last item already holds every message of the run;
        # there is nothing to merge across items.
        final_state = None
        for state in self.agent.stream({"messages": self.messages}, stream_mode="values"):
            if print_output:
                print(state)
                print("----")
            final_state = state

        produced = [] if final_state is None else final_state['messages'][history_size:]
        replies = [msg for msg in produced if isinstance(msg, AIMessage)]

        # Only this turn's replies are read, each exactly once: earlier
        # replies stored in the history must not leak into the new answer.
        final_message = ''.join(self._plain_text(reply.content) for reply in replies)

        register_cost(self.category, self._cost_of(replies))

        self.messages.append(AIMessage(content=final_message))
        return final_message

    @staticmethod
    def _plain_text(content):
        """Return the string parts of a message content (str or list of parts)."""
        if isinstance(content, str):
            return content
        # Non-string parts (e.g. structured content blocks) are skipped.
        return ''.join(part for part in content if isinstance(part, str))

    def _cost_of(self, replies):
        """Price the token usage reported by ``replies`` using ``COSTS``."""
        if not replies:
            return 0
        prices = COSTS[self.model_name]
        total_cost = 0
        for reply in replies:
            usage = reply.usage_metadata
            if not usage:
                # No usage attached to this message: nothing to bill.
                continue
            for cost_name in ('input_tokens', 'output_tokens'):
                total_cost += prices[cost_name] * usage.get(cost_name, 0)
        return total_cost
class LangChainAnthropicToolformer(Toolformer):
    """Toolformer that creates conversations backed by an Anthropic chat model."""

    def __init__(self, model_name, api_key):
        """Remember the model identifier and the API key used to reach it."""
        self.api_key = api_key
        self.model_name = model_name

    def new_conversation(self, prompt, tools, category=None):
        """Build a fresh conversation seeded with ``prompt`` as system message.

        Args:
            prompt: Text of the system message that opens the history.
            tools: Iterable of tool objects exposing ``as_annotated_function()``.
            category: Label handed to the conversation unchanged.

        Returns:
            A ``LangChainConversation`` wrapping a newly created ReAct agent.
        """
        # Wrap each tool's annotated function so the agent can call it.
        langchain_tools = []
        for agora_tool in tools:
            annotated = agora_tool.as_annotated_function()
            langchain_tools.append(function_to_tool(annotated))

        chat_model = ChatAnthropic(model_name=self.model_name, api_key=self.api_key)
        react_agent = create_react_agent(chat_model, langchain_tools)

        initial_history = [SystemMessage(prompt)]
        return LangChainConversation(self.model_name, react_agent, initial_history, category)
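# Example usage (kept commented out so that importing this module has no side effects).
# SECURITY: a literal API key used to be embedded in this example. It has been
# redacted here; treat the old key as compromised and revoke it. Load secrets
# from the environment instead of committing them.
#
#weather_tool = AgoraTool("WeatherForecastAPI", "A simple tool that returns the weather", [StringParameter(
#    name="location",
#    description="The name of the location for which the weather forecast is requested.",
#    required=True
#)], lambda location: 'Sunny', {
#    "type": "string"
#})
#
#tools = [agora_tool_to_langchain(weather_tool)]  # NOTE(review): agora_tool_to_langchain is not defined in this file; new_conversation converts the tools itself
#toolformer = LangChainAnthropicToolformer("claude-3-sonnet-20240229", os.environ["ANTHROPIC_API_KEY"])  # NOTE(review): this model name has no entry in COSTS
#conversation = toolformer.new_conversation('You are a weather bot', [weather_tool])
#
#print(conversation.chat('What is the weather in San Francisco?'))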