import logging
import os

from dotenv import load_dotenv
from langchain.agents import AgentExecutor, Tool, create_tool_calling_agent
from langchain.memory import ConversationSummaryMemory
from langchain_community.utilities import SQLDatabase
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

from config.settings import Settings

load_dotenv()

open_api_key_token = os.getenv('OPENAI_API_KEY')
db_uri = Settings.DB_URI
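
# Configuration sketch (assumptions about the runtime environment, not part of the
# original settings module):
#   OPENAI_API_KEY  - read from the environment via load_dotenv()/os.getenv above
#   VECTOR_DB_PATH  - root folder containing per-document FAISS indexes
#                     (walked in find_index_for_document below)
#   Settings.DB_URI - SQLAlchemy-style connection URI, e.g.
#                     "postgresql+psycopg2://user:pass@host/dbname" (hypothetical value)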

class ChatAgentService:
    """Agent service that answers questions from a SQL database or a FAISS document index."""

    def __init__(self):
        self.db = SQLDatabase.from_uri(db_uri)
        self.llm = ChatOpenAI(
            model="gpt-3.5-turbo-0125",
            api_key=open_api_key_token,
            max_tokens=150,
            temperature=0.2,
        )
        self.memory = ConversationSummaryMemory(llm=self.llm, return_messages=True)

        self.tools = [
            Tool(
                name="DatabaseQuery",
                func=self.database_tool,
                description=(
                    "Queries the SQL database using dynamically generated SQL queries based on "
                    "user questions. Aimed at retrieving structured data such as counts, specific "
                    "records, or summaries from predefined schemas."
                ),
            ),
            Tool(
                name="DocumentData",
                func=self.document_data_tool,
                description=(
                    "Searches indexed documents for information relevant to the user's query. "
                    "Handles unstructured data from document formats such as PDF, DOCX, or TXT."
                ),
            ),
        ]

        prompt_template = self.setup_prompt()
        self.agent = create_tool_calling_agent(self.llm, self.tools, prompt_template)
        self.agent_executor = AgentExecutor(
            agent=self.agent, tools=self.tools, memory=self.memory, verbose=True
        )

    def setup_prompt(self):
        # create_tool_calling_agent expects "agent_scratchpad" as a message placeholder
        # rather than a plain string variable, so the prompt is built with from_messages.
        # ConversationSummaryMemory injects the running summary under its default
        # "history" key, picked up by the optional placeholder below.
        return ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    "You are an assistant that helps with database queries and document retrieval. "
                    "Base your responses strictly on available data and avoid assumptions. "
                    "If the question pertains to numerical data or structured queries, use the DatabaseQuery tool. "
                    "If the question relates to content within documents, use the DocumentData tool.",
                ),
                MessagesPlaceholder(variable_name="history", optional=True),
                ("human", "{input}"),
                MessagesPlaceholder(variable_name="agent_scratchpad"),
            ]
        )

    def database_tool(self, question):
        sql_query = self.generate_sql_query(question)
        return self.run_query(sql_query)

    def get_schema(self, _):
        return self.db.get_table_info()

    def generate_sql_query(self, question):
        template_query_generation = """Generate a SQL query to answer the user's question based on the available database schema.

{schema}

Question: {question}

SQL Query:"""

        prompt_query_generation = ChatPromptTemplate.from_template(template_query_generation)

        input_data = {'question': question}

        # Inject the live schema, prompt the model for SQL, and stop generation
        # before the model starts inventing a "SQL Result:" section.
        sql_chain = (
            RunnablePassthrough.assign(schema=self.get_schema)
            | prompt_query_generation
            | self.llm.bind(stop=["\nSQL Result:"])
            | StrOutputParser()
        )

        return sql_chain.invoke(input_data)
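
    # Illustrative only: given a question such as "How many orders were placed in
    # 2023?", the chain returns raw SQL text (e.g. a SELECT COUNT(*) statement)
    # whose exact shape depends on the schema returned by get_schema().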

    def run_query(self, query):
        try:
            logging.info(f"Executing SQL query: {query}")
            result = self.db.run(query)
            logging.info(f"Query successful: {result}")
            return result
        except Exception as e:
            logging.error(f"Error executing query: {query}, Error: {str(e)}")
            return None

    def document_data_tool(self, query):
        try:
            logging.info(f"Searching documents for query: {query}")
            embeddings = OpenAIEmbeddings(api_key=open_api_key_token)
            index_paths = self.find_index_for_document(query)
            responses = []
            for index_path in index_paths:
                vector_store = FAISS.load_local(index_path, embeddings, allow_dangerous_deserialization=True)
                response = self.query_vector_store(vector_store, query)
                responses.append(response)
            logging.info(f"Document search results: {responses}")
            return "\n".join(responses)
        except Exception as e:
            logging.error(f"Error in document data tool for query: {query}, Error: {str(e)}")
            return "Error processing document query."

    def find_index_for_document(self, query):
        # The query itself is not used for filtering; every FAISS index folder
        # found under VECTOR_DB_PATH is returned and searched.
        base_path = os.getenv('VECTOR_DB_PATH')
        index_paths = []
        for root, dirs, _files in os.walk(base_path):
            for directory in dirs:
                folder = os.path.join(root, directory)
                if 'index.faiss' in os.listdir(folder):
                    index_paths.append(folder)
        return index_paths
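
    # Assumed on-disk layout (FAISS save_local writes an index.faiss / index.pkl
    # pair per folder), e.g.:
    #   <VECTOR_DB_PATH>/<document_name>/index.faiss
    #   <VECTOR_DB_PATH>/<document_name>/index.pkl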

    def query_vector_store(self, vector_store, query):
        docs = vector_store.similarity_search(query)
        return '\n\n'.join([doc.page_content for doc in docs])

    def answer_question(self, user_question):
        try:
            logging.info(f"Received question: {user_question}")
            response = self.agent_executor.invoke({"input": user_question})
            output_response = response.get("output", "No valid response generated.")
            logging.info(f"Response generated: {output_response}")
            return output_response
        except Exception as e:
            logging.error(f"Error processing question: {user_question}, Error: {str(e)}")
            return f"An error occurred: {str(e)}"