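"""Toolformer and Conversation implementations backed by Anthropic chat models via LangChain and LangGraph."""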
# Import relevant functionality
import sys

sys.path.append('.')

from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from langchain_core.tools import tool as function_to_tool
from langchain_anthropic import ChatAnthropic
from langgraph.prebuilt import create_react_agent

from toolformers.base import Conversation, StringParameter, Toolformer, Tool as AgoraTool
from utils import register_cost
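# Approximate Anthropic API pricing in USD per token, used to estimate the
# cost of each conversation via register_cost.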
COSTS = {
    'claude-3-5-sonnet-latest': {
        'input_tokens': 3e-6,
        'output_tokens': 15e-6
    },
    'claude-3-5-haiku-latest': {
        'input_tokens': 1e-6,
        'output_tokens': 5e-6
    }
}
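# Conversation backed by a LangGraph ReAct agent: accumulates the message
# history and registers the cost of every turn.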
class LangChainConversation(Conversation):
    def __init__(self, model_name, agent, messages, category=None):
        self.model_name = model_name
        self.agent = agent
        self.messages = messages
        self.category = category
    def chat(self, message, role='user', print_output=True) -> str:
        self.messages.append(HumanMessage(content=message))

        # stream_mode="values" yields the full agent state after each step,
        # so only the most recent chunk needs to be kept.
        final_state = None
        for chunk in self.agent.stream({"messages": self.messages}, stream_mode="values"):
            if print_output:
                print(chunk)
                print("----")
            final_state = chunk

        # Only consider the messages generated during this run (everything
        # after the messages that were passed in as input).
        new_messages = final_state['messages'][len(self.messages):]

        # Concatenate the assistant's textual output.
        final_message = ''
        for response in new_messages:
            if isinstance(response, AIMessage):
                content = response.content
                if isinstance(content, str):
                    final_message += content
                else:
                    # Anthropic may return a list of content blocks (text, tool_use, ...).
                    for content_chunk in content:
                        if isinstance(content_chunk, str):
                            final_message += content_chunk
                        elif isinstance(content_chunk, dict) and content_chunk.get('type') == 'text':
                            final_message += content_chunk['text']

        # Register the dollar cost of every model call made in this run.
        total_cost = 0
        for response in new_messages:
            if isinstance(response, AIMessage) and response.usage_metadata is not None:
                for cost_name in ['input_tokens', 'output_tokens']:
                    total_cost += COSTS[self.model_name][cost_name] * response.usage_metadata[cost_name]
        register_cost(self.category, total_cost)

        self.messages.append(AIMessage(content=final_message))
        return final_message
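# Toolformer that builds a LangGraph ReAct agent on top of an Anthropic chat
# model for every new conversation.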
class LangChainAnthropicToolformer(Toolformer):
    def __init__(self, model_name, api_key):
        self.model_name = model_name
        self.api_key = api_key

    def new_conversation(self, prompt, tools, category=None):
        langchain_tools = [function_to_tool(tool.as_annotated_function()) for tool in tools]
        model = ChatAnthropic(model_name=self.model_name, api_key=self.api_key)
        agent_executor = create_react_agent(model, langchain_tools)

        return LangChainConversation(self.model_name, agent_executor, [SystemMessage(prompt)], category)
# Example usage (kept commented out; substitute a real Anthropic API key):
#
# weather_tool = AgoraTool("WeatherForecastAPI", "A simple tool that returns the weather", [StringParameter(
#     name="location",
#     description="The name of the location for which the weather forecast is requested.",
#     required=True
# )], lambda location: 'Sunny', {
#     "type": "string"
# })
#
# toolformer = LangChainAnthropicToolformer('claude-3-5-sonnet-latest', '<your-anthropic-api-key>')
# conversation = toolformer.new_conversation('You are a weather bot', [weather_tool])
#
# print(conversation.chat('What is the weather in San Francisco?'))