Update app.py
app.py CHANGED
@@ -11,31 +11,32 @@ from fastapi.middleware.cors import CORSMiddleware
 from langchain.chains import create_history_aware_retriever, create_retrieval_chain
 from langchain.chains.combine_documents import create_stuff_documents_chain
 from langchain_community.chat_message_histories import ChatMessageHistory
-from
-from
-from langchain_core.runnables
+from langchain.schema import BaseChatMessageHistory
+from langchain.prompts.chat import ChatPromptTemplate, MessagesPlaceholder
+from langchain_core.runnables import RunnableWithMessageHistory
 from pinecone import Pinecone
 from pinecone_text.sparse import BM25Encoder
-from
+from langchain_community.embeddings import OpenAIEmbeddings
 from langchain_community.retrievers import PineconeHybridSearchRetriever
 from langchain.retrievers import ContextualCompressionRetriever
-from langchain_community.chat_models import
+from langchain_community.chat_models import ChatOpenAI
 from langchain.retrievers.document_compressors import CrossEncoderReranker
 from langchain_community.cross_encoders import HuggingFaceCrossEncoder
-from
+from langchain.prompts import PromptTemplate
 import re
+from langchain_huggingface import HuggingFaceEmbeddings

 # Load environment variables
 load_dotenv(".env")
 USER_AGENT = os.getenv("USER_AGENT")
-
+OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
 SECRET_KEY = os.getenv("SECRET_KEY")
 PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
 SESSION_ID_DEFAULT = "abc123"

 # Set environment variables
 os.environ['USER_AGENT'] = USER_AGENT
-os.environ["
+os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
 os.environ["TOKENIZERS_PARALLELISM"] = 'true'

 # Initialize FastAPI app and CORS
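One side note on the new environment handling: `os.environ` only accepts strings, so if any of these keys is missing from `.env`, the bare assignment raises a `TypeError`. A minimal fail-fast guard, assuming the same four keys the diff reads via `os.getenv`:

```python
import os

from dotenv import load_dotenv

load_dotenv(".env")

# os.environ[...] = None raises TypeError("str expected, not NoneType"),
# so check the required keys up front and fail with a clearer message.
REQUIRED_KEYS = ("USER_AGENT", "OPENAI_API_KEY", "SECRET_KEY", "PINECONE_API_KEY")
missing = [key for key in REQUIRED_KEYS if os.getenv(key) is None]
if missing:
    raise RuntimeError(f"Missing environment variables: {', '.join(missing)}")
```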
@@ -74,6 +75,7 @@ bm25 = BM25Encoder().load("./mbzuai-policies.json")

 # Initialize models and retriever
 embed_model = HuggingFaceEmbeddings(model_name="jinaai/jina-embeddings-v3", model_kwargs={"trust_remote_code":True})
+
 retriever = PineconeHybridSearchRetriever(
     embeddings=embed_model,
     sparse_encoder=bm25,
@@ -83,11 +85,11 @@ retriever = PineconeHybridSearchRetriever(
 )

 # Initialize LLM
-llm =
+llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", max_tokens=512)

 # Initialize Reranker
 model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
-compressor = CrossEncoderReranker(model=model, top_n=
+compressor = CrossEncoderReranker(model=model, top_n=10)

 compression_retriever = ContextualCompressionRetriever(
     base_compressor=compressor, base_retriever=retriever
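Taken together, these two hunks stack the retrieval pipeline in three stages: hybrid (dense jina-embeddings-v3 plus sparse BM25) recall from Pinecone, then cross-encoder reranking that keeps the top 10. A quick smoke test of the assembled `compression_retriever` might look like the sketch below; the query string is made up for illustration:

```python
# Retrievers expose the Runnable interface, so .invoke returns a list of
# reranked Documents: hybrid recall first, then BAAI/bge-reranker-base
# re-scores each (query, passage) pair and keeps the 10 best.
docs = compression_retriever.invoke("What is the annual leave policy?")
for rank, doc in enumerate(docs, start=1):
    print(rank, doc.metadata.get("source", "?"), doc.page_content[:80])
```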
@@ -122,15 +124,15 @@ When responding to queries, follow these guidelines:

 2. Formatting for Readability:
    - Provide the entire response in proper markdown format.
-   - Use structured
-   - Use
+   - Use structured Markdown elements such as headings, subheadings, lists, tables, and links.
+   - Use emphasis on headings, important texts, and phrases.

 3. Proper Citations:
-   - ALWAYS USE INLINE CITATIONS with
-   - The inline citations should be in the format [1], [2], etc., in the response with links to reference sources.
+   - ALWAYS USE INLINE CITATIONS with embedded source URLs where users can verify information or explore further.
+   - The inline citations should be in the format [[1]], [[2]], etc., in the response with links to reference sources.
+   - Then at the end of the response, list out the citations with their sources.

 FOLLOW ALL THE GIVEN INSTRUCTIONS, FAILURE TO DO SO WILL RESULT IN TERMINATION OF THE CHAT.
-
 {context}
 """
 qa_prompt = ChatPromptTemplate.from_messages(
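The `from_messages` call is cut off by the diff, but given the imports added in the first hunk (`ChatPromptTemplate`, `MessagesPlaceholder`) it very likely follows the standard history-aware RAG layout. A sketch, with `system_prompt` standing in for the instruction block above (an assumed variable name):

```python
qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),            # instruction block ending in {context}
        MessagesPlaceholder("chat_history"),  # prior turns, injected per session
        ("human", "{input}"),                 # the current user question
    ]
)
```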
@@ -165,7 +167,6 @@ conversational_rag_chain = RunnableWithMessageHistory(
     output_messages_key="answer",
 )

-
 # WebSocket endpoint with streaming
 @app.websocket("/ws")
 async def websocket_endpoint(websocket: WebSocket):
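The last hunk shows only the tail of `conversational_rag_chain`, so for orientation, here is a hedged reconstruction of how the named pieces conventionally wire together; `contextualize_q_prompt`, `store`, and `get_session_history` are assumptions, everything else appears in the diff:

```python
history_aware_retriever = create_history_aware_retriever(
    llm, compression_retriever, contextualize_q_prompt  # rewrites follow-ups into standalone queries
)
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)

store = {}  # assumed in-memory session store

def get_session_history(session_id: str) -> BaseChatMessageHistory:
    # One ChatMessageHistory per session id, created lazily.
    if session_id not in store:
        store[session_id] = ChatMessageHistory()
    return store[session_id]

conversational_rag_chain = RunnableWithMessageHistory(
    rag_chain,
    get_session_history,
    input_messages_key="input",
    history_messages_key="chat_history",
    output_messages_key="answer",
)
```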
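The handler body falls outside the diff; below is a minimal streaming sketch, not the commit's actual code, assuming the client sends one plain-text question per message and the default session id from the config block:

```python
from fastapi import WebSocket, WebSocketDisconnect

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    try:
        while True:
            question = await websocket.receive_text()
            # astream yields incremental dict chunks; only the "answer"
            # pieces are forwarded to the client as they arrive.
            async for chunk in conversational_rag_chain.astream(
                {"input": question},
                config={"configurable": {"session_id": SESSION_ID_DEFAULT}},
            ):
                if "answer" in chunk:
                    await websocket.send_text(chunk["answer"])
    except WebSocketDisconnect:
        pass
```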