File size: 6,370 Bytes
8bac072 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import os
import logging
from dotenv import load_dotenv
from langchain.memory import ConversationSummaryMemory
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.utilities import SQLDatabase
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI
from langchain_openai import OpenAIEmbeddings
from langchain.agents import create_tool_calling_agent, AgentExecutor, Tool
from langchain_community.vectorstores import FAISS
from config.settings import Settings
# Load environment variables before any of them are read below.
load_dotenv()
# OpenAI API key; passed to the chat model and the embeddings client in ChatAgentService.
open_api_key_token = os.getenv('OPENAI_API_KEY')
# The database URI is taken from the project Settings object
# (an earlier revision read it from the POST_DB_URI environment variable).
db_uri = Settings.DB_URI
class ChatAgentService:
    """Question-answering agent backed by a SQL database and FAISS document indexes.

    The service wires together:

    * a ``SQLDatabase`` opened from the module-level ``db_uri``,
    * a ``ChatOpenAI`` model used both by the agent and for SQL generation,
    * a ``ConversationSummaryMemory`` attached to the ``AgentExecutor``,
    * two tools: ``DatabaseQuery`` (LLM-generated SQL run against the
      database) and ``DocumentData`` (similarity search over every FAISS
      index found under ``VECTOR_DB_PATH``).

    Call :meth:`answer_question` to run the agent on a user question.
    """

    def __init__(self):
        """Create the database handle, LLM, memory, tools and agent executor."""
        # Database setup
        self.db = SQLDatabase.from_uri(db_uri)
        self.llm = ChatOpenAI(
            model="gpt-3.5-turbo-0125",
            api_key=open_api_key_token,
            max_tokens=150,
            temperature=0.2,
        )
        self.memory = ConversationSummaryMemory(llm=self.llm, return_messages=True)

        # Tools setup
        # NOTE(review): ``tool_choice`` is normally an argument of the chat
        # model's ``bind_tools``, not a ``Tool`` field -- confirm that passing
        # it here has any effect with the installed LangChain version.
        self.tools = [
            Tool(
                name="DatabaseQuery",
                func=self.database_tool,
                description="Queries the SQL database using dynamically generated SQL queries based on user questions. Aimed to retrieve structured data like counts, specific records, or summaries from predefined schemas.",
                tool_choice="required",
            ),
            Tool(
                name="DocumentData",
                func=self.document_data_tool,
                description="Searches through indexed documents to find relevant information based on user queries. Handles unstructured data from various document formats like PDF, DOCX, or TXT files.",
                tool_choice="required",
            ),
        ]

        # Agent setup
        prompt_template = self.setup_prompt()
        # The plain LLM is passed here: ``Runnable.bind`` attaches keyword
        # arguments to the model call itself, and ``memory`` is not a model
        # argument.  Conversation memory is handled by the AgentExecutor.
        self.agent = create_tool_calling_agent(self.llm, self.tools, prompt_template)
        self.agent_executor = AgentExecutor(
            agent=self.agent,
            tools=self.tools,
            memory=self.memory,
            verbose=True,
        )

    def setup_prompt(self):
        """Return the agent prompt with ``{input}`` and ``{agent_scratchpad}`` slots.

        NOTE(review): the template has no placeholder for the memory's
        variables, so the conversation summary presumably never reaches the
        model -- confirm whether that is intended.
        """
        prompt_text = (
            "\n"
            "You are an assistant that helps with database queries and document retrieval.\n"
            "Please base your responses strictly on available data and avoid assumptions.\n"
            "If the question pertains to numerical data or structured queries, use the DatabaseQuery tool.\n"
            "If the question relates to content within various documents, use the DocumentData tool.\n"
            "Question: {input}\n"
            "{agent_scratchpad}\n"
        )
        return ChatPromptTemplate.from_template(prompt_text)

    def database_tool(self, question):
        """Answer ``question`` by generating a SQL query and running it.

        Returns the database result, or a short error message when the query
        could not be executed, so the agent never receives ``None`` as a tool
        observation.
        """
        sql_query = self.generate_sql_query(question)
        result = self.run_query(sql_query)
        if result is None:
            return "Error executing database query."
        return result

    def get_schema(self, _):
        """Return the database table info; the argument is ignored.

        The unused parameter lets this method be used directly as a
        ``RunnablePassthrough.assign`` callback, which passes the chain input.
        """
        return self.db.get_table_info()

    @staticmethod
    def _strip_sql_fences(text):
        """Remove surrounding whitespace and a wrapping Markdown code fence from ``text``."""
        cleaned = text.strip()
        if cleaned.startswith("```"):
            # Drop the whole opening fence line (``` or ```sql).
            newline_at = cleaned.find("\n")
            cleaned = cleaned[newline_at + 1:] if newline_at != -1 else cleaned[3:]
            cleaned = cleaned.rstrip()
            if cleaned.endswith("```"):
                cleaned = cleaned[:-3]
        return cleaned.strip()

    def generate_sql_query(self, question):
        """Ask the LLM for a SQL query that answers ``question``.

        The database schema is injected into the prompt by the chain itself
        (via :meth:`get_schema`).  The model output is returned with any
        wrapping Markdown code fence removed.
        """
        template_query_generation = (
            "Generate a SQL query to answer the user's question based on the available database schema.\n"
            "{schema}\n"
            "Question: {question}\n"
            "SQL Query:"
        )
        prompt_query_generation = ChatPromptTemplate.from_template(template_query_generation)
        sql_chain = (
            RunnablePassthrough.assign(schema=self.get_schema)
            | prompt_query_generation
            | self.llm.bind(stop="\nSQL Result:")
            | StrOutputParser()
        )
        raw_query = sql_chain.invoke({'question': question})
        return self._strip_sql_fences(raw_query)

    def run_query(self, query):
        """Execute ``query`` against the database.

        Returns the result of ``self.db.run(query)``, or ``None`` if execution
        raised; the error is logged.

        NOTE(review): the SQL text is produced by the LLM from user input and
        executed as-is -- make sure the database account is restricted
        accordingly.
        """
        try:
            logging.info("Executing SQL query: %s", query)
            result = self.db.run(query)
        except Exception as e:
            logging.error("Error executing query: %s, Error: %s", query, e)
            return None
        logging.info("Query successful: %s", result)
        return result

    def document_data_tool(self, query):
        """Search every discovered FAISS index for ``query``.

        Returns the per-index results joined by newlines, or a fixed error
        message if anything fails.
        """
        try:
            logging.info("Searching documents for query: %s", query)
            embeddings = OpenAIEmbeddings(api_key=open_api_key_token)
            responses = []
            for index_path in self.find_index_for_document(query):
                # NOTE(review): allow_dangerous_deserialization=True unpickles
                # the stored index; only load indexes from trusted locations.
                vector_store = FAISS.load_local(
                    index_path, embeddings, allow_dangerous_deserialization=True
                )
                responses.append(self.query_vector_store(vector_store, query))
            logging.info("Document search results: %s", responses)
            return "\n".join(responses)
        except Exception as e:
            logging.error("Error in document data tool for query: %s, Error: %s", query, e)
            return "Error processing document query."

    def find_index_for_document(self, query):
        """Return the directories below ``VECTOR_DB_PATH`` that contain ``index.faiss``.

        ``query`` is currently unused.  Each returned path ends with a path
        separator.  An unset or empty ``VECTOR_DB_PATH`` yields an empty list.
        """
        base_path = os.getenv('VECTOR_DB_PATH')
        if not base_path:
            # os.walk(None) would raise TypeError; report the misconfiguration instead.
            logging.warning("VECTOR_DB_PATH is not set; no document indexes to search.")
            return []
        index_paths = []
        for root, dirs, _files in os.walk(base_path):
            for dir_name in dirs:
                candidate = os.path.join(root, dir_name)
                if 'index.faiss' in os.listdir(candidate):
                    index_paths.append(os.path.join(candidate, ''))
        return index_paths

    def query_vector_store(self, vector_store, query):
        """Run a similarity search and join the matching page contents with blank lines."""
        docs = vector_store.similarity_search(query)
        return '\n\n'.join(doc.page_content for doc in docs)

    def answer_question(self, user_question):
        """Run the agent on ``user_question`` and return its text output.

        Returns the executor's ``"output"`` value (or a fallback string when
        it is missing).  Any exception is logged and converted into an
        ``"An error occurred: ..."`` message instead of being raised.
        """
        try:
            logging.info("Received question: %s", user_question)
            response = self.agent_executor.invoke({"input": user_question})
            output_response = response.get("output", "No valid response generated.")
            logging.info("Response generated: %s", output_response)
            return output_response
        except Exception as e:
            logging.error("Error processing question: %s, Error: %s", user_question, e)
            return f"An error occurred: {str(e)}"
|