Tora-Agent / agent.py
sivan22's picture
Update agent.py
321e99a verified
raw
history blame
1.9 kB
from langgraph.prebuilt import create_react_agent
from langgraph.checkpoint.memory import MemorySaver
from typing import Any, Iterator
from tools import search, get_commentaries, read_text
from llm_providers import LLMProvider
class Agent:
    """ReAct-style agent with per-thread conversation memory.

    Wraps ``create_react_agent`` with the ``read_text``, ``get_commentaries``
    and ``search`` tools, a system prompt read from ``system_prompt.txt``, and
    a ``MemorySaver`` checkpointer. Conversations are addressed by an integer
    thread id; ``clear_chat`` moves to a new id instead of deleting history.
    """

    # Path of the system prompt, relative to the current working directory.
    SYSTEM_PROMPT_PATH = "system_prompt.txt"

    def __init__(self, index_path: str, api_keys):
        """Create the agent using the first available LLM provider.

        Args:
            index_path: Accepted for interface compatibility.
                NOTE(review): not used by this class — confirm callers rely
                on it only positionally.
            api_keys: Passed straight through to ``LLMProvider``.

        Raises:
            IndexError: if ``LLMProvider`` reports no available providers.
        """
        self.llm_provider = LLMProvider(api_keys)
        default_provider = self.llm_provider.get_available_providers()[0]
        self.llm = self.llm_provider.get_provider(default_provider)
        self.memory_saver = MemorySaver()
        self.tools = [read_text, get_commentaries, search]
        self.graph = self._build_graph()
        self.current_thread_id = 1

    @classmethod
    def _load_system_prompt(cls) -> str:
        """Read the system prompt file, closing it and decoding as UTF-8."""
        # Explicit encoding: the locale default differs across platforms.
        with open(cls.SYSTEM_PROMPT_PATH, "r", encoding="utf-8") as prompt_file:
            return prompt_file.read()

    def _build_graph(self):
        """Build the ReAct graph for the current LLM, tools and checkpointer.

        The prompt is re-read on every call, so ``set_llm`` picks up edits
        made to the prompt file since construction.
        """
        return create_react_agent(
            model=self.llm,
            checkpointer=self.memory_saver,
            tools=self.tools,
            state_modifier=self._load_system_prompt(),
        )

    @staticmethod
    def _thread_config(thread_id) -> dict[str, Any]:
        """Return the runnable config addressing ``thread_id``.

        The id is always stringified so that writes (``chat``) and reads
        (``get_chat_history``) use the same key type.
        """
        return {"configurable": {"thread_id": str(thread_id)}}

    def set_llm(self, provider_name: str):
        """Switch to ``provider_name``'s model and rebuild the graph.

        The checkpointer is reused, so existing thread histories are kept.
        """
        self.llm = self.llm_provider.get_provider(provider_name)
        self.graph = self._build_graph()

    def get_llm(self) -> Any:
        """Return the model object currently driving the graph."""
        return self.llm

    def clear_chat(self):
        """Start a new conversation by advancing to the next thread id."""
        self.current_thread_id += 1

    def chat(self, message) -> Iterator[dict[str, Any]]:
        """Chat with the agent and stream responses including tool calls and their results.

        Returns the iterator produced by ``graph.stream`` with
        ``stream_mode="values"`` for the current thread.
        """
        inputs = {"messages": [("user", message)]}
        return self.graph.stream(
            inputs,
            stream_mode="values",
            config=self._thread_config(self.current_thread_id),
        )

    def get_chat_history(self, id=None) -> "dict[str, Any] | None":
        """Return the latest stored checkpoint for a thread.

        Args:
            id: Thread id to look up; defaults to the current thread.

        Returns:
            Whatever ``memory_saver.get`` returns for that thread (expected:
            the checkpoint mapping, or None if nothing was stored — confirm
            against the installed langgraph version).
        """
        thread_id = self.current_thread_id if id is None else id
        # The checkpointer's ``get`` takes a config mapping, not keyword ids.
        return self.memory_saver.get(self._thread_config(thread_id))