"""Flask + Socket.IO chat server backed by a conversational RAG chain.

The module wires a Pinecone hybrid (dense + BM25) retriever and a Groq-hosted
LLM into a history-aware retrieval chain, and streams answers to clients over
Socket.IO ``response`` events.
"""

import asyncio
import os
from typing import Dict, Iterable, Optional, Set

from dotenv import load_dotenv
from flask import Flask, request, render_template
from flask_cors import CORS
from flask_socketio import SocketIO, emit, join_room, leave_room
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.retrievers import PineconeHybridSearchRetriever
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
from pinecone import Pinecone
from pinecone_text.sparse import BM25Encoder

# Load environment variables
load_dotenv(".env")

USER_AGENT = os.getenv("USER_AGENT")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
SECRET_KEY = os.getenv("SECRET_KEY")
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
SESSION_ID_DEFAULT = "abc123"

# Set environment variables.
# os.environ only accepts strings, so assigning an unset (None) value would
# raise TypeError at import time; only export the values that are present.
if USER_AGENT is not None:
    os.environ['USER_AGENT'] = USER_AGENT
if GROQ_API_KEY is not None:
    os.environ["GROQ_API_KEY"] = GROQ_API_KEY
os.environ["TOKENIZERS_PARALLELISM"] = 'true'

# Initialize Flask app and SocketIO with CORS
app = Flask(__name__)
CORS(app)
socketio = SocketIO(app, cors_allowed_origins="*")

app.config['SESSION_COOKIE_SECURE'] = True  # Use HTTPS
app.config['SESSION_COOKIE_HTTPONLY'] = True
app.config['SESSION_COOKIE_SAMESITE'] = 'Lax'
app.config['SECRET_KEY'] = SECRET_KEY


# Function to initialize Pinecone connection
def initialize_pinecone(index_name: str):
    """Return a handle to the Pinecone index called ``index_name``.

    Any exception raised while creating the client or the index handle is
    printed and then re-raised unchanged.
    """
    try:
        pc = Pinecone(api_key=PINECONE_API_KEY)
        return pc.Index(index_name)
    except Exception as e:
        print(f"Error initializing Pinecone: {e}")
        raise


# Initialize Pinecone index and BM25 encoder
pinecone_index = initialize_pinecone("traveler-demo-website-vectorstore")
bm25 = BM25Encoder().load("./bm25_traveler_website.json")

# Initialize models and retriever
embed_model = HuggingFaceEmbeddings(
    model_name="Alibaba-NLP/gte-large-en-v1.5",
    model_kwargs={"trust_remote_code": True},
)
retriever = PineconeHybridSearchRetriever(
    embeddings=embed_model,
    sparse_encoder=bm25,
    index=pinecone_index,
    top_k=20,
    alpha=0.5,
)

# Initialize LLM
llm = ChatGroq(
    model="llama-3.1-70b-versatile",
    temperature=0,
    max_tokens=1024,
    max_retries=2,
)

# Contextualization prompt and retriever
contextualize_q_system_prompt = """Given a chat history and the latest user question \
which might reference context in the chat history, formulate a standalone question \
which can be understood without the chat history. Do NOT answer the question, \
just reformulate it if needed and otherwise return it as is.
"""
contextualize_q_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", contextualize_q_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)
history_aware_retriever = create_history_aware_retriever(
    llm, retriever, contextualize_q_prompt
)

# QA system prompt and chain
qa_system_prompt = """You are a highly skilled information retrieval assistant. Use the following context to answer questions effectively. \
If you don't know the answer, simply state that you don't know. \
Provide answers in proper HTML format and keep them concise. \
When responding to queries, follow these guidelines: \
1. Provide Clear Answers: \
- Ensure the response directly addresses the query with accurate and relevant information. \
2. Include Detailed References: \
- Links to Sources: Include URLs to credible sources where users can verify information or explore further. \
- Reference Sites: Mention specific websites or platforms that offer additional information. \
- Downloadable Materials: Provide links to any relevant downloadable resources if applicable. \
3. Formatting for Readability: \
- The answer should be in a proper HTML format with appropriate tags. \
- Use bullet points or numbered lists where applicable to present information clearly. \
- Highlight key details using bold or italics. \
- Provide proper and meaningful abbreviations for urls. Do not include naked urls. \
4. Organize Content Logically: \
- Structure the content in a logical order, ensuring easy navigation and understanding for the user. \

{context}
"""
qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", qa_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)

# Retrieval and Generative (RAG) Chain
rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)

# Chat message history storage, keyed by the client-supplied session id
store: Dict[str, BaseChatMessageHistory] = {}

# Session ids used by each connected socket (keyed by request.sid), so that a
# disconnect only discards the histories that socket was actually using.
_sid_sessions: Dict[str, Set[str]] = {}


def clean_temporary_data(session_ids: Optional[Iterable[str]] = None) -> None:
    """Discard stored chat histories.

    Args:
        session_ids: Session ids whose histories should be removed. Unknown
            ids are ignored. When omitted (``None``), every stored history is
            removed.
    """
    if session_ids is None:
        store.clear()
        return
    for session_id in session_ids:
        store.pop(session_id, None)


def get_session_history(session_id: str) -> BaseChatMessageHistory:
    """Return the chat history for ``session_id``, creating it on first use."""
    if session_id not in store:
        store[session_id] = ChatMessageHistory()
    return store[session_id]


# Conversational RAG chain with message history
conversational_rag_chain = RunnableWithMessageHistory(
    rag_chain,
    get_session_history,
    input_messages_key="input",
    history_messages_key="chat_history",
    output_messages_key="answer",
)

# Only the "answer" field is streamed to clients; build this once instead of
# on every incoming message.
answer_chain = conversational_rag_chain.pick("answer")


# Function to handle WebSocket connection
@socketio.on('connect')
def handle_connect():
    """Log the new connection and acknowledge it to the client."""
    print(f"Client connected: {request.sid}")
    emit('connection_response', {'message': 'Connected successfully.'})


# Function to handle WebSocket disconnection
@socketio.on('disconnect')
def handle_disconnect():
    """Log the disconnect and drop the histories this socket was using.

    A session that another still-connected socket has also used is kept, so
    one client leaving does not erase a conversation someone else is in.
    """
    print(f"Client disconnected: {request.sid}")
    owned = _sid_sessions.pop(request.sid, set())
    still_in_use = set().union(*_sid_sessions.values())
    clean_temporary_data(owned - still_in_use)


# Function to handle WebSocket messages
@socketio.on('message')
def handle_message(data):
    """Answer a question and stream the result back to the sender.

    Args:
        data: Expected to be a dict with a ``question`` string and an
            optional ``session_id``; ``SESSION_ID_DEFAULT`` is used when no
            session id is supplied.

    Each streamed chunk is emitted as a ``response`` event to the sending
    socket. Invalid payloads and processing failures are reported as a
    ``response`` event carrying an ``error`` field.
    """
    # The payload comes straight from the client, so validate its shape
    # before touching it.
    if not isinstance(data, dict):
        emit('response', {"error": "Invalid message payload."}, room=request.sid)
        return

    question = data.get('question')
    if not isinstance(question, str) or not question.strip():
        emit('response', {"error": "A non-empty 'question' is required."}, room=request.sid)
        return

    # str() keeps the id hashable and usable as a dict key whatever the
    # client sent.
    session_id = str(data.get('session_id') or SESSION_ID_DEFAULT)
    _sid_sessions.setdefault(request.sid, set()).add(session_id)

    try:
        for chunk in answer_chain.stream(
            {"input": question},
            config={"configurable": {"session_id": session_id}},
        ):
            emit('response', chunk, room=request.sid)
    except Exception as e:
        print(f"Error during message handling: {e}")
        emit('response', {"error": "An error occurred while processing your request."}, room=request.sid)


# Home route
@app.route("/")
def index_view():
    """Serve the chat page."""
    return render_template('chat.html')


# Main function to run the app
if __name__ == '__main__':
    socketio.run(app, debug=False)