0504ankitsharma committed on
Commit
36a19d8
·
verified ·
1 Parent(s): 470d648
Files changed (1) hide show
  1. app/main.py +93 -116
app/main.py CHANGED
@@ -1,6 +1,9 @@
1
  import os
2
  import re
3
- from openai import OpenAI
 
 
 
4
  from langchain_openai import ChatOpenAI
5
  from langchain.text_splitter import RecursiveCharacterTextSplitter
6
  from langchain.chains.combine_documents import create_stuff_documents_chain
@@ -8,63 +11,24 @@ from langchain_core.prompts import ChatPromptTemplate
8
  from langchain.chains import create_retrieval_chain
9
  from langchain_community.vectorstores import FAISS
10
  from langchain_community.document_loaders import UnstructuredWordDocumentLoader as DocxLoader
11
- from fastapi.middleware.cors import CORSMiddleware
12
- from fastapi import FastAPI
13
- from pydantic import BaseModel
14
  from langchain_community.embeddings import HuggingFaceBgeEmbeddings
15
- import nltk # Importing NLTK
16
- import time
17
-
18
- import os
19
- import nltk
20
-
21
- # Set writable paths for cache and data
22
- cache_dir = '/tmp'
23
- nltk_data_path = os.path.join(cache_dir, 'nltk_data')
24
-
25
- # Configure NLTK and other library paths
26
- os.environ['TRANSFORMERS_CACHE'] = os.path.join(cache_dir, 'transformers_cache')
27
- os.environ['HF_HOME'] = os.path.join(cache_dir, 'huggingface')
28
- os.environ['XDG_CACHE_HOME'] = cache_dir
29
-
30
- # Add NLTK data path
31
- nltk.data.path.append(nltk_data_path)
32
-
33
- # Ensure the directory exists
34
- try:
35
- os.makedirs(nltk_data_path, exist_ok=True)
36
- except OSError as e:
37
- print(f"Error creating directory {nltk_data_path}: {e}")
38
- raise
39
-
40
- # Download required NLTK resources
41
- try:
42
- nltk.download('punkt', download_dir=nltk_data_path)
43
- print("NLTK 'punkt' resource downloaded successfully.")
44
- except Exception as e:
45
- print(f"Error downloading NLTK resources: {e}")
46
- raise
47
-
48
-
49
 
50
 
 
51
  def clean_response(response):
52
- # Remove any leading/trailing whitespace, including newlines
 
53
  cleaned = response.strip()
54
-
55
- # Remove any enclosing quotation marks
56
  cleaned = re.sub(r'^["\']+|["\']+$', '', cleaned)
57
-
58
- # Replace multiple newlines with a single newline
59
  cleaned = re.sub(r'\n+', '\n', cleaned)
60
-
61
- # Remove any remaining '\n' characters
62
  cleaned = cleaned.replace('\\n', '')
63
-
64
  return cleaned
65
 
 
 
66
  app = FastAPI()
67
 
 
68
  app.add_middleware(
69
  CORSMiddleware,
70
  allow_origins=["*"],
@@ -73,103 +37,116 @@ app.add_middleware(
73
  allow_headers=["*"],
74
  )
75
 
76
- openai_api_key = os.environ.get('OPENAI_API_KEY')
 
 
 
 
 
 
77
  llm = ChatOpenAI(
78
  api_key=openai_api_key,
79
- model_name="gpt-4-turbo-preview", # or "gpt-3.5-turbo" for a more economical option
80
- temperature=0.7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  )
82
 
 
 
83
  @app.get("/")
84
  def read_root():
85
- return {"Hello": "World"}
 
86
 
 
87
  class Query(BaseModel):
88
  query_text: str
89
 
90
- prompt = ChatPromptTemplate.from_template(
91
- """
92
- You are a helpful assistant designed specifically for the Thapar Institute of Engineering and Technology (TIET), a renowned technical college. Your task is to answer all queries related to TIET. Every response you provide should be relevant to the context of TIET. If a question falls outside of this context, please decline by stating, 'Sorry, I cannot help with that.' If you do not know the answer to a question, do not attempt to fabricate a response; instead, politely decline.
93
- You may elaborate on your answers slightly to provide more information, but avoid sounding boastful or exaggerating. Stay focused on the context provided.
94
- If the query is not related to TIET or falls outside the context of education, respond with:
95
- "Sorry, I cannot help with that. I'm specifically designed to answer questions about the Thapar Institute of Engineering and Technology.
96
- For more information, please contact at our toll-free number: 18002024100 or E-mail us at admissions@thapar.edu
97
- <context>
98
- {context}
99
- </context>
100
- Question: {input}
101
- """
102
- )
103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  def vector_embedding():
105
  try:
106
- file_path = "./data/Data.docx"
107
- if not os.path.exists(file_path):
108
- print(f"The file {file_path} does not exist.")
109
- return {"response": "Error: Data file not found"}
110
 
111
- loader = DocxLoader(file_path)
 
112
  documents = loader.load()
113
-
114
- print(f"Loaded document: {file_path}")
115
 
116
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
117
  chunks = text_splitter.split_documents(documents)
118
-
119
  print(f"Created {len(chunks)} chunks.")
120
 
121
- model_name = "BAAI/bge-base-en"
122
- encode_kwargs = {'normalize_embeddings': True}
123
- model_norm = HuggingFaceBgeEmbeddings(model_name=model_name, encode_kwargs=encode_kwargs)
124
-
125
- db = FAISS.from_documents(chunks, model_norm)
126
- db.save_local("./vectors_db")
127
-
128
  print("Vector store created and saved successfully.")
129
- return {"response": "Vector Store DB Is Ready"}
130
-
131
  except Exception as e:
132
- print(f"An error occurred: {str(e)}")
133
- return {"response": f"Error: {str(e)}"}
 
134
 
 
135
  def get_embeddings():
136
- model_name = "BAAI/bge-base-en"
137
  encode_kwargs = {'normalize_embeddings': True}
138
- model_norm = HuggingFaceBgeEmbeddings(model_name=model_name, encode_kwargs=encode_kwargs)
139
- return model_norm
140
 
141
- @app.post("/chat") # Changed from /anthropic to /chat
142
- def read_item(query: Query):
143
- try:
144
- embeddings = get_embeddings()
145
- vectors = FAISS.load_local("./vectors_db", embeddings, allow_dangerous_deserialization=True)
146
- except Exception as e:
147
- print(f"Error loading vector store: {str(e)}")
148
- return {"response": "Vector Store Not Found or Error Loading. Please run /setup first."}
149
-
150
- prompt1 = query.query_text
151
- if prompt1:
152
- start = time.process_time()
153
- document_chain = create_stuff_documents_chain(llm, prompt)
154
- retriever = vectors.as_retriever()
155
- retrieval_chain = create_retrieval_chain(retriever, document_chain)
156
- response = retrieval_chain.invoke({'input': prompt1})
157
- print("Response time:", time.process_time() - start)
158
-
159
- # Apply the cleaning function to the response
160
- cleaned_response = clean_response(response['answer'])
161
-
162
- # For debugging, print the cleaned response
163
- print("Cleaned response:", repr(cleaned_response))
164
-
165
- return cleaned_response
166
- else:
167
- return "No Query Found"
168
-
169
- @app.get("/setup")
170
- def setup():
171
- return vector_embedding()
172
 
 
173
  if __name__ == "__main__":
174
  import uvicorn
175
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
  import os
2
  import re
3
+ import time
4
+ from fastapi import FastAPI, HTTPException
5
+ from fastapi.middleware.cors import CORSMiddleware
6
+ from pydantic import BaseModel
7
  from langchain_openai import ChatOpenAI
8
  from langchain.text_splitter import RecursiveCharacterTextSplitter
9
  from langchain.chains.combine_documents import create_stuff_documents_chain
 
11
  from langchain.chains import create_retrieval_chain
12
  from langchain_community.vectorstores import FAISS
13
  from langchain_community.document_loaders import UnstructuredWordDocumentLoader as DocxLoader
 
 
 
14
  from langchain_community.embeddings import HuggingFaceBgeEmbeddings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
 
17
# Utility function to clean the response
def clean_response(response):
    """Normalize raw model output before it is returned to the client.

    The text is trimmed, quote characters wrapping the whole answer are
    removed, runs of newlines are collapsed to a single newline, and literal
    backslash-n sequences (escaped newlines that leaked through as text) are
    dropped.

    Args:
        response: The raw answer string; may be None or empty.

    Returns:
        The cleaned answer, or a fixed apology message when nothing usable
        remains (None, empty, whitespace-only, or quotes-only input).
    """
    fallback = "Sorry, I couldn't generate a response."
    if not response:
        return fallback

    cleaned = response.strip()
    # Strip again after unwrapping so padding that sat inside the quotes
    # (e.g. '"  text  "') does not survive.
    cleaned = re.sub(r'^["\']+|["\']+$', '', cleaned).strip()
    cleaned = re.sub(r'\n+', '\n', cleaned)
    # '\\n' is the two-character sequence backslash + n, not a real newline.
    cleaned = cleaned.replace('\\n', '').strip()

    # Input that was only whitespace/quotes is truthy above but cleans down to
    # nothing; answer with the fallback rather than an empty string.
    return cleaned or fallback
26
 
27
+
28
+ # Initialize FastAPI app
29
  app = FastAPI()
30
 
31
+ # CORS Middleware setup
32
  app.add_middleware(
33
  CORSMiddleware,
34
  allow_origins=["*"],
 
37
  allow_headers=["*"],
38
  )
39
 
40
# Global Variables
# API key handed to the ChatOpenAI client below; os.getenv returns None when
# the variable is unset.
openai_api_key = os.getenv('OPENAI_API_KEY')  # Ensure this is set in your environment
# Location of the FAISS index: written by vector_embedding(), read by chat().
VECTOR_DB_PATH = "./vectors_db"
# Word document that vector_embedding() loads, splits and indexes.
DATA_FILE_PATH = "./data/Data.docx"
# Model identifier passed to HuggingFaceBgeEmbeddings in get_embeddings().
MODEL_NAME = "BAAI/bge-base-en"

# Initialize OpenAI LLM
# Created once at import time and shared by every /chat request.
llm = ChatOpenAI(
    api_key=openai_api_key,
    model_name="gpt-4-turbo-preview",  # Use "gpt-3.5-turbo" for cost efficiency if required
    temperature=0.7,
)

# Prompt template
# Restricts the assistant to TIET-related questions and supplies the refusal
# wording. {input} receives the user's question (see the invoke call in chat());
# {context} is presumably filled with the retrieved document chunks by the
# stuff-documents chain -- confirm against the LangChain docs.
# NOTE(review): the double quote opening the refusal message ("Sorry, I cannot
# help with that. ...) is never closed in the template text.
prompt = ChatPromptTemplate.from_template(
    """
You are a helpful assistant designed specifically for the Thapar Institute of Engineering and Technology (TIET), a renowned technical college. Your task is to answer all queries related to TIET. Every response you provide should be relevant to the context of TIET. If a question falls outside of this context, please decline by stating, 'Sorry, I cannot help with that.' If you do not know the answer to a question, do not attempt to fabricate a response; instead, politely decline.
If the query is not related to TIET or falls outside the context of education, respond with:
"Sorry, I cannot help with that. I'm specifically designed to answer questions about the Thapar Institute of Engineering and Technology.
For more information, please contact at our toll-free number: 18002024100 or E-mail us at admissions@thapar.edu
<context>
{context}
</context>
Question: {input}
"""
)
66
 
67
+
68
# Route: Home
@app.get("/")
def read_root():
    """Serve the static welcome payload at the API root."""
    welcome = {"message": "Welcome to the ThaparGPT API!"}
    return welcome
72
+
73
 
74
# Route: Chat Endpoint
class Query(BaseModel):
    """Request body accepted by the POST /chat endpoint."""

    # The user's question; chat() forwards it to the retrieval chain as 'input'.
    query_text: str
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
@app.post("/chat")
def chat(query: Query):
    """Answer a question using the saved vector store and the LLM.

    The query is validated, the FAISS index is loaded from VECTOR_DB_PATH, a
    retrieval chain is assembled from the module-level ``llm`` and ``prompt``,
    and the chain's answer is cleaned before being returned.

    Args:
        query: Request body carrying the user's question in ``query_text``.

    Returns:
        A dict of the form ``{"response": <cleaned answer>}``.

    Raises:
        HTTPException: 400 when the query is empty or whitespace-only;
            500 when the vector store cannot be loaded or the chain fails.
    """
    # Validate before any heavy work so an empty request is rejected with a
    # 400 without loading the embedding model or the index.
    query_text = query.query_text.strip() if query.query_text else ""
    if not query_text:
        raise HTTPException(status_code=400, detail="No query found in the request.")

    try:
        # Load the vector store
        embeddings = get_embeddings()
        # NOTE(security): allow_dangerous_deserialization unpickles the saved
        # index. Acceptable only while VECTOR_DB_PATH is written exclusively
        # by this application's own /setup route; never point it at untrusted
        # files.
        vectors = FAISS.load_local(VECTOR_DB_PATH, embeddings, allow_dangerous_deserialization=True)
    except Exception as e:
        print(f"Error loading vector store: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail="Vector Store not found or loading failed. Please run /setup first.",
        ) from e

    # perf_counter measures wall-clock time; process_time would count only CPU
    # time and so ignore the time spent waiting on the LLM call.
    start_time = time.perf_counter()
    document_chain = create_stuff_documents_chain(llm, prompt)
    retriever = vectors.as_retriever()
    retrieval_chain = create_retrieval_chain(retriever, document_chain)

    try:
        response = retrieval_chain.invoke({'input': query_text})
    except Exception as e:
        print(f"Error during query processing: {str(e)}")
        raise HTTPException(status_code=500, detail="Error processing the query.") from e

    print("Response time:", time.perf_counter() - start_time)
    cleaned_response = clean_response(response.get('answer', ''))
    return {"response": cleaned_response}
108
+
109
+
110
+ # Route: Setup Endpoint
111
@app.get("/setup")
def setup():
    """Build the vector store and relay the builder's result to the caller."""
    outcome = vector_embedding()
    return outcome
114
+
115
+
116
# Utility: Create Vector Embeddings
def vector_embedding():
    """Build the FAISS index from the source document and save it to disk.

    Loads DATA_FILE_PATH, splits it into overlapping chunks, embeds the chunks
    with the model from get_embeddings(), and writes the index to
    VECTOR_DB_PATH.

    Returns:
        ``{"response": "Vector Store DB is ready."}`` on success.

    Raises:
        HTTPException: 404 when the data file does not exist; 500 when
            loading, splitting, embedding or saving fails.
    """
    # Checked outside the try block on purpose: inside it, the broad handler
    # below would catch this 404 and re-raise it as a 500.
    if not os.path.exists(DATA_FILE_PATH):
        print(f"The file {DATA_FILE_PATH} does not exist.")
        raise HTTPException(status_code=404, detail="Data file not found.")

    try:
        # Load and split document
        loader = DocxLoader(DATA_FILE_PATH)
        documents = loader.load()
        print(f"Loaded document: {DATA_FILE_PATH}")

        text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
        chunks = text_splitter.split_documents(documents)
        print(f"Created {len(chunks)} chunks.")

        # Create vector store
        embeddings = get_embeddings()
        db = FAISS.from_documents(chunks, embeddings)
        db.save_local(VECTOR_DB_PATH)
        print("Vector store created and saved successfully.")
        return {"response": "Vector Store DB is ready."}
    except Exception as e:
        print(f"Error during setup: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error during setup: {str(e)}") from e
141
+
142
 
143
# Utility: Load Embedding Model
# Process-wide embedding model, created on the first get_embeddings() call.
_EMBEDDINGS = None


def get_embeddings():
    """Return the shared BGE embedding model, creating it on first use.

    Building HuggingFaceBgeEmbeddings on every call would reconstruct the
    model for each /chat request; keeping one instance per process avoids
    that repeated cost. Embeddings are normalized, as before.

    Returns:
        The HuggingFaceBgeEmbeddings instance configured with MODEL_NAME.
    """
    global _EMBEDDINGS
    if _EMBEDDINGS is None:
        encode_kwargs = {'normalize_embeddings': True}
        _EMBEDDINGS = HuggingFaceBgeEmbeddings(model_name=MODEL_NAME, encode_kwargs=encode_kwargs)
    return _EMBEDDINGS
 
147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
# Main entry point
# Executed only when this file is run directly, not when it is imported.
if __name__ == "__main__":
    # Imported here so uvicorn is required only for direct execution.
    import uvicorn
    # Serve the app on all network interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)