0504ankitsharma committed
Commit ac4e9ed · verified · 1 Parent(s): 807315c

Update app/main.py

Files changed (1):
  1. app/main.py +32 -56
app/main.py CHANGED
@@ -9,15 +9,11 @@ from langchain.chains import create_retrieval_chain
 from langchain_community.vectorstores import FAISS
 from langchain_community.document_loaders import UnstructuredWordDocumentLoader as DocxLoader
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi import FastAPI, Depends
+from fastapi import FastAPI, Request
 from pydantic import BaseModel
 from langchain_community.embeddings import HuggingFaceBgeEmbeddings
-import nltk # Importing NLTK
+import nltk
 import time
-from typing import Dict, Optional
-from fastapi.sessions import SessionMiddleware
-from fastapi.requests import Request
-from fastapi.responses import JSONResponse
 
 # Set writable paths for cache and data
 cache_dir = '/tmp'
@@ -47,18 +43,10 @@ except Exception as e:
     raise
 
 def clean_response(response):
-    # Remove any leading/trailing whitespace, including newlines
     cleaned = response.strip()
-
-    # Remove any enclosing quotation marks
-    cleaned = re.sub(r'^\"+|\"+$', '', cleaned)
-
-    # Replace multiple newlines with a single newline
+    cleaned = re.sub(r'^\"|\"$', '', cleaned)
     cleaned = re.sub(r'\n+', '\n', cleaned)
-
-    # Remove any remaining '\n' characters
     cleaned = cleaned.replace('\\n', '')
-
     return cleaned
 
 app = FastAPI()
@@ -71,9 +59,6 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-# Adding session middleware for contextual memory
-app.add_middleware(SessionMiddleware, secret_key="your-secret-key")
-
 openai_api_key = os.environ.get('OPENAI_API_KEY')
 llm = ChatOpenAI(
     api_key=openai_api_key,
@@ -81,39 +66,28 @@ llm = ChatOpenAI(
     temperature=0.7
 )
 
+conversation_history = {}  # Dictionary to maintain contextual memory
+
 @app.get("/")
 def read_root():
     return {"Hello": "World"}
 
 class Query(BaseModel):
+    session_id: str  # Unique identifier for user session
     query_text: str
 
-# In-memory storage for contextual memory
-user_sessions: Dict[str, Dict[str, str]] = {}
-
-def get_user_context(request: Request):
-    user_id = request.client.host
-    if user_id not in user_sessions:
-        user_sessions[user_id] = {}
-    return user_id, user_sessions[user_id]
+prompt_template = ChatPromptTemplate.from_template(
+    """
+    You are a helpful assistant designed specifically for the Thapar Institute of Engineering and Technology (TIET), a renowned technical college. Your task is to answer all queries related to TIET. Every response you provide should be relevant to the context of TIET. If a question falls outside of this context, please decline by stating, 'Sorry, I cannot help with that.'
 
-prompt = ChatPromptTemplate.from_template(
-    """
-    You are a helpful assistant designed specifically for the Thapar Institute of Engineering and Technology (TIET), a renowned technical college. Your task is to answer all queries related to TIET. Every response you provide should be relevant to the context of TIET. If a question falls outside of this context, please decline by stating, 'Sorry, I cannot help with that.' If you do not know the answer to a question, do not attempt to fabricate a response; instead, politely decline.
-    You may elaborate on your answers slightly to provide more information, but avoid sounding boastful or exaggerating. Stay focused on the context provided.
+    If the query is not related to TIET or falls outside the context of education, respond with:
+    "Sorry, I cannot help with that. I'm specifically designed to answer questions about the Thapar Institute of Engineering and Technology. For more information, please contact at our toll-free number: 18002024100 or E-mail us at admissions@thapar.edu"
 
-    If the query is not related to TIET or falls outside the context of education, respond with:
-    "Sorry, I cannot help with that. I'm specifically designed to answer questions about the Thapar Institute of Engineering and Technology.
-    For more information, please contact at our toll-free number: 18002024100 or E-mail us at admissions@thapar.edu
-
-    Previous Context:
-    {previous_context}
-
-    <context>
-    {context}
-    </context>
-    Question: {input}
-    """
+    <context>
+    {context}
+    </context>
+    Question: {input}
+    """
 )
 
 def vector_embedding():
@@ -130,16 +104,16 @@ def vector_embedding():
 
     text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
     chunks = text_splitter.split_documents(documents)
-
+
     print(f"Created {len(chunks)} chunks.")
 
    model_name = "BAAI/bge-base-en"
     encode_kwargs = {'normalize_embeddings': True}
     model_norm = HuggingFaceBgeEmbeddings(model_name=model_name, encode_kwargs=encode_kwargs)
-
+
     db = FAISS.from_documents(chunks, model_norm)
     db.save_local("./vectors_db")
-
+
     print("Vector store created and saved successfully.")
     return {"response": "Vector Store DB Is Ready"}
 
@@ -154,9 +128,12 @@ def get_embeddings():
     return model_norm
 
 @app.post("/chat")
-def read_item(query: Query, request: Request):
+def chat_endpoint(query: Query):
     try:
-        user_id, user_context = get_user_context(request)
+        session_id = query.session_id
+        if session_id not in conversation_history:
+            conversation_history[session_id] = []
+
         embeddings = get_embeddings()
         vectors = FAISS.load_local("./vectors_db", embeddings, allow_dangerous_deserialization=True)
     except Exception as e:
@@ -166,22 +143,21 @@ def read_item(query: Query, request: Request):
     prompt1 = query.query_text
     if prompt1:
        start = time.process_time()
-        document_chain = create_stuff_documents_chain(llm, prompt)
+        document_chain = create_stuff_documents_chain(llm, prompt_template)
         retriever = vectors.as_retriever()
         retrieval_chain = create_retrieval_chain(retriever, document_chain)
 
-        # Add previous context
-        previous_context = user_context.get("context", "None")
-        response = retrieval_chain.invoke({'input': prompt1, 'previous_context': previous_context})
-        print("Response time:", time.process_time() - start)
+        # Combine context from conversation history
+        context = "\n".join(conversation_history[session_id])
+        response = retrieval_chain.invoke({'input': prompt1, 'context': context})
 
-        # Apply the cleaning function to the response
         cleaned_response = clean_response(response['answer'])
 
-        # Update context
-        user_context["context"] = cleaned_response
+        # Update conversation history
+        conversation_history[session_id].append(f"User: {prompt1}")
+        conversation_history[session_id].append(f"Assistant: {cleaned_response}")
 
-        print("Cleaned response:", repr(cleaned_response))
+        print("Response time:", time.process_time() - start)
         return {"response": cleaned_response}
     else:
         return {"response": "No Query Found"}
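
For reference, a minimal client sketch against the updated endpoint, assuming the app is served locally on port 8000 (the base URL and session identifier below are illustrative). Each request now carries a session_id alongside query_text, and requests sharing a session_id share one conversation history:

import requests  # third-party HTTP client, assumed installed

BASE_URL = "http://localhost:8000"  # assumed local dev server
SESSION_ID = "demo-session-1"       # any stable string works

# Two turns in the same session: the second call sees the first
# exchange via the server-side conversation_history dict.
for question in ["What programs does TIET offer?", "Tell me more about admissions."]:
    resp = requests.post(
        f"{BASE_URL}/chat",
        json={"session_id": SESSION_ID, "query_text": question},
    )
    print(resp.json()["response"])

Since conversation_history is an in-process dict, this memory is per-worker and is lost on restart; a multi-worker deployment would need an external store.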