import os
from dotenv import load_dotenv
from utils.embeddings import get_embeddings
from utils.vector_store import load_vector_store
from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
from langchain_community.tools import DuckDuckGoSearchResults
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.tools import tool
from langchain_openai import ChatOpenAI
from langchain.agents import create_openai_functions_agent, AgentExecutor
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableWithMessageHistory

load_dotenv()

# Web search tool: DuckDuckGo news results, capped at 2 hits per query
wrapper = DuckDuckGoSearchAPIWrapper(max_results=2)
search_web = DuckDuckGoSearchResults(api_wrapper=wrapper, source="news")


@tool
def rag_tool(query: str) -> str:
    """Answer a query about the NCERT Sound chapter by retrieving relevant context from the vector store."""
    embeddings = get_embeddings()
    vector_store = load_vector_store(embeddings)
    retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": 2})
    system_prompt = (
        "You are an assistant for question-answering tasks. "
        "Use the following pieces of retrieved context to answer "
        "the question. If you don't know the answer, say that you "
        "don't know. Use three sentences maximum and keep the "
        "answer concise."
        "\n\n"
        "{context}"
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            ("human", "{input}"),
        ]
    )
    question_answer_chain = create_stuff_documents_chain(get_model_use(), prompt)
    rag_chain = create_retrieval_chain(retriever, question_answer_chain)
    response = rag_chain.invoke({"input": query})
    return response["answer"]


@tool
def calculate_word_count(words: str) -> int:
    """Count the number of words in the given text."""
    return len(words.split())


tools = [rag_tool, search_web, calculate_word_count]

prompt = ChatPromptTemplate.from_messages([
    ("system", """
You are an assistant helping the user with queries related to the NCERT Sound chapter and real-time events or factual information. Follow these instructions:

1. **NCERT Sound Chapter Queries**:
    - Use the rag_tool to answer any query about the NCERT Sound chapter, such as:
        - What is an echo?
        - How is sound propagated?
        - What are the applications of ultrasound?
    - STRICT RULE: Do not use external tools after using rag_tool.

2. **Non-Sound Chapter Queries**:
    - For any questions unrelated to the Sound chapter, such as real-time events, news, or factual information not covered in the chapter, use the search tool to provide the latest and most accurate information.

3. **Counting Words in a Response**:
    - If the query involves counting the number of words in a response, use the `calculate_word_count` tool to determine the word count.

4. **Clarification**:
    - If the query is unclear or ambiguous, clarify the user's intent before selecting the appropriate tool or providing a response.

Be concise, accurate, and use the appropriate tool or knowledge based on the query type. Do not confuse the tools or mix the instructions for different query types.
"""),
    MessagesPlaceholder(variable_name="chat_history"),
    ("user", "Form input details: {input}"),
    MessagesPlaceholder(variable_name="agent_scratchpad"),
])


def get_model_use():
    """Return the chat model shared by the RAG chain and the agent."""
    return ChatOpenAI(api_key=os.environ.get("OPEN_API_KEY"), temperature=0)


def init_agent():
    """Build the tool-calling agent and wrap it with in-memory chat history."""
    llm = get_model_use()
    agent = create_openai_functions_agent(llm, tools, prompt)
    agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

    message_history = ChatMessageHistory()
    agent_with_chat_history = RunnableWithMessageHistory(
        agent_executor,
        lambda session_id: message_history,
        input_messages_key="input",
        history_messages_key="chat_history",
    )
    return agent_with_chat_history


def get_agent_response(agent, user_input, session_id="agentic_trial"):
    """Run a single agent turn and return the final answer text."""
    response = agent.invoke(
        {"input": user_input},
        config={"configurable": {"session_id": session_id}},
    )
    return response["output"]
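

if __name__ == "__main__":
    # Minimal usage sketch, assuming a .env file that defines OPEN_API_KEY and a
    # prebuilt vector store that load_vector_store() can open (both are set up
    # outside this module). The second turn exercises the word-count tool against
    # the chat history kept by RunnableWithMessageHistory.
    agent = init_agent()
    print(get_agent_response(agent, "What is an echo?"))
    print(get_agent_response(agent, "How many words were in your last answer?"))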