Browse files
app/main.py CHANGED (+93 -116)
@@ -1,6 +1,9 @@
 import os
 import re
-
+import time
+from fastapi import FastAPI, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel
 from langchain_openai import ChatOpenAI
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.chains.combine_documents import create_stuff_documents_chain
@@ -8,63 +11,24 @@ from langchain_core.prompts import ChatPromptTemplate
 from langchain.chains import create_retrieval_chain
 from langchain_community.vectorstores import FAISS
 from langchain_community.document_loaders import UnstructuredWordDocumentLoader as DocxLoader
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi import FastAPI
-from pydantic import BaseModel
 from langchain_community.embeddings import HuggingFaceBgeEmbeddings
-import nltk # Importing NLTK
-import time
-
-import os
-import nltk
-
-# Set writable paths for cache and data
-cache_dir = '/tmp'
-nltk_data_path = os.path.join(cache_dir, 'nltk_data')
-
-# Configure NLTK and other library paths
-os.environ['TRANSFORMERS_CACHE'] = os.path.join(cache_dir, 'transformers_cache')
-os.environ['HF_HOME'] = os.path.join(cache_dir, 'huggingface')
-os.environ['XDG_CACHE_HOME'] = cache_dir
-
-# Add NLTK data path
-nltk.data.path.append(nltk_data_path)
-
-# Ensure the directory exists
-try:
-    os.makedirs(nltk_data_path, exist_ok=True)
-except OSError as e:
-    print(f"Error creating directory {nltk_data_path}: {e}")
-    raise
-
-# Download required NLTK resources
-try:
-    nltk.download('punkt', download_dir=nltk_data_path)
-    print("NLTK 'punkt' resource downloaded successfully.")
-except Exception as e:
-    print(f"Error downloading NLTK resources: {e}")
-    raise
-
-


+# Utility function to clean the response
 def clean_response(response):
-
+    if not response:
+        return "Sorry, I couldn't generate a response."
     cleaned = response.strip()
-
-    # Remove any enclosing quotation marks
     cleaned = re.sub(r'^["\']+|["\']+$', '', cleaned)
-
-    # Replace multiple newlines with a single newline
     cleaned = re.sub(r'\n+', '\n', cleaned)
-
-    # Remove any remaining '\n' characters
     cleaned = cleaned.replace('\\n', '')
-
     return cleaned

+
+# Initialize FastAPI app
 app = FastAPI()

+# CORS Middleware setup
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
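The hunk above also adds an empty-response guard to clean_response. As a quick standalone check, here is the helper copied verbatim from the new file, applied to an invented sample answer:

import re

# clean_response as defined in the new app/main.py
def clean_response(response):
    if not response:
        return "Sorry, I couldn't generate a response."
    cleaned = response.strip()
    cleaned = re.sub(r'^["\']+|["\']+$', '', cleaned)  # strip enclosing quotes
    cleaned = re.sub(r'\n+', '\n', cleaned)            # collapse runs of newlines
    cleaned = cleaned.replace('\\n', '')               # drop literal "\n" sequences
    return cleaned

raw = '"TIET is in Patiala.\n\n\nIt was established in 1956.\\n"'
print(clean_response(raw))
# TIET is in Patiala.
# It was established in 1956.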
@@ -73,103 +37,116 @@ app.add_middleware(
     allow_headers=["*"],
 )

-
+# Global Variables
+openai_api_key = os.getenv('OPENAI_API_KEY') # Ensure this is set in your environment
+VECTOR_DB_PATH = "./vectors_db"
+DATA_FILE_PATH = "./data/Data.docx"
+MODEL_NAME = "BAAI/bge-base-en"
+
+# Initialize OpenAI LLM
 llm = ChatOpenAI(
     api_key=openai_api_key,
-    model_name="gpt-4-turbo-preview", #
-    temperature=0.7
+    model_name="gpt-4-turbo-preview", # Use "gpt-3.5-turbo" for cost efficiency if required
+    temperature=0.7,
+)
+
+# Prompt template
+prompt = ChatPromptTemplate.from_template(
+    """
+    You are a helpful assistant designed specifically for the Thapar Institute of Engineering and Technology (TIET), a renowned technical college. Your task is to answer all queries related to TIET. Every response you provide should be relevant to the context of TIET. If a question falls outside of this context, please decline by stating, 'Sorry, I cannot help with that.' If you do not know the answer to a question, do not attempt to fabricate a response; instead, politely decline.
+    If the query is not related to TIET or falls outside the context of education, respond with:
+    "Sorry, I cannot help with that. I'm specifically designed to answer questions about the Thapar Institute of Engineering and Technology.
+    For more information, please contact at our toll-free number: 18002024100 or E-mail us at admissions@thapar.edu
+    <context>
+    {context}
+    </context>
+    Question: {input}
+    """
 )

+
+# Route: Home
 @app.get("/")
 def read_root():
-    return {"
+    return {"message": "Welcome to the ThaparGPT API!"}
+

+# Route: Chat Endpoint
 class Query(BaseModel):
     query_text: str

-prompt = ChatPromptTemplate.from_template(
-    """
-    You are a helpful assistant designed specifically for the Thapar Institute of Engineering and Technology (TIET), a renowned technical college. Your task is to answer all queries related to TIET. Every response you provide should be relevant to the context of TIET. If a question falls outside of this context, please decline by stating, 'Sorry, I cannot help with that.' If you do not know the answer to a question, do not attempt to fabricate a response; instead, politely decline.
-    You may elaborate on your answers slightly to provide more information, but avoid sounding boastful or exaggerating. Stay focused on the context provided.
-    If the query is not related to TIET or falls outside the context of education, respond with:
-    "Sorry, I cannot help with that. I'm specifically designed to answer questions about the Thapar Institute of Engineering and Technology.
-    For more information, please contact at our toll-free number: 18002024100 or E-mail us at admissions@thapar.edu
-    <context>
-    {context}
-    </context>
-    Question: {input}
-    """
-)

+@app.post("/chat")
+def chat(query: Query):
+    try:
+        # Load the vector store
+        embeddings = get_embeddings()
+        vectors = FAISS.load_local(VECTOR_DB_PATH, embeddings, allow_dangerous_deserialization=True)
+    except Exception as e:
+        print(f"Error loading vector store: {str(e)}")
+        raise HTTPException(status_code=500, detail="Vector Store not found or loading failed. Please run /setup first.")
+
+    # Retrieve and process the query
+    query_text = query.query_text
+    if query_text:
+        start_time = time.process_time()
+        document_chain = create_stuff_documents_chain(llm, prompt)
+        retriever = vectors.as_retriever()
+        retrieval_chain = create_retrieval_chain(retriever, document_chain)
+
+        try:
+            response = retrieval_chain.invoke({'input': query_text})
+        except Exception as e:
+            print(f"Error during query processing: {str(e)}")
+            raise HTTPException(status_code=500, detail="Error processing the query.")
+
+        print("Response time:", time.process_time() - start_time)
+        cleaned_response = clean_response(response.get('answer', ''))
+        return {"response": cleaned_response}
+    else:
+        raise HTTPException(status_code=400, detail="No query found in the request.")
+
+
+# Route: Setup Endpoint
+@app.get("/setup")
+def setup():
+    return vector_embedding()
+
+
+# Utility: Create Vector Embeddings
 def vector_embedding():
     try:
-
-
-
-        return {"response": "Error: Data file not found"}
+        if not os.path.exists(DATA_FILE_PATH):
+            print(f"The file {DATA_FILE_PATH} does not exist.")
+            raise HTTPException(status_code=404, detail="Data file not found.")

-
+        # Load and split document
+        loader = DocxLoader(DATA_FILE_PATH)
         documents = loader.load()
-
-        print(f"Loaded document: {file_path}")
+        print(f"Loaded document: {DATA_FILE_PATH}")

         text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
         chunks = text_splitter.split_documents(documents)
-
         print(f"Created {len(chunks)} chunks.")

-
-
-
-
-        db = FAISS.from_documents(chunks, model_norm)
-        db.save_local("./vectors_db")
-
+        # Create vector store
+        embeddings = get_embeddings()
+        db = FAISS.from_documents(chunks, embeddings)
+        db.save_local(VECTOR_DB_PATH)
         print("Vector store created and saved successfully.")
-        return {"response": "Vector Store DB
-
+        return {"response": "Vector Store DB is ready."}
     except Exception as e:
-        print(f"
-
+        print(f"Error during setup: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"Error during setup: {str(e)}")
+

+# Utility: Load Embedding Model
 def get_embeddings():
-    model_name = "BAAI/bge-base-en"
     encode_kwargs = {'normalize_embeddings': True}
-
-    return model_norm
+    return HuggingFaceBgeEmbeddings(model_name=MODEL_NAME, encode_kwargs=encode_kwargs)

-@app.post("/chat") # Changed from /anthropic to /chat
-def read_item(query: Query):
-    try:
-        embeddings = get_embeddings()
-        vectors = FAISS.load_local("./vectors_db", embeddings, allow_dangerous_deserialization=True)
-    except Exception as e:
-        print(f"Error loading vector store: {str(e)}")
-        return {"response": "Vector Store Not Found or Error Loading. Please run /setup first."}
-
-    prompt1 = query.query_text
-    if prompt1:
-        start = time.process_time()
-        document_chain = create_stuff_documents_chain(llm, prompt)
-        retriever = vectors.as_retriever()
-        retrieval_chain = create_retrieval_chain(retriever, document_chain)
-        response = retrieval_chain.invoke({'input': prompt1})
-        print("Response time:", time.process_time() - start)
-
-        # Apply the cleaning function to the response
-        cleaned_response = clean_response(response['answer'])
-
-        # For debugging, print the cleaned response
-        print("Cleaned response:", repr(cleaned_response))
-
-        return cleaned_response
-    else:
-        return "No Query Found"
-
-@app.get("/setup")
-def setup():
-    return vector_embedding()

+# Main entry point
 if __name__ == "__main__":
     import uvicorn
     uvicorn.run(app, host="0.0.0.0", port=8000)
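After this change, the /setup and /chat routes share state only through the FAISS index persisted at VECTOR_DB_PATH. A minimal sketch of that handoff, with two toy strings standing in for the Data.docx chunks (assumes langchain_community, faiss-cpu, and sentence-transformers are installed; the first run downloads the BAAI/bge-base-en model):

from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import FAISS

embeddings = HuggingFaceBgeEmbeddings(
    model_name="BAAI/bge-base-en",
    encode_kwargs={'normalize_embeddings': True},
)

# /setup side: embed the chunks and persist the index to disk.
db = FAISS.from_texts(
    ["TIET is located in Patiala, Punjab.",
     "Admissions queries go to admissions@thapar.edu."],
    embeddings,
)
db.save_local("./vectors_db")

# /chat side: reload the index. load_local unpickles part of the index,
# which is why the explicit allow_dangerous_deserialization opt-in is required.
vectors = FAISS.load_local("./vectors_db", embeddings, allow_dangerous_deserialization=True)
print(vectors.similarity_search("Where is TIET?", k=1)[0].page_content)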
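End to end, the intended flow is one GET /setup to build the index, then a POST /chat per question. A minimal client sketch, assuming the app runs locally on port 8000 with OPENAI_API_KEY exported and ./data/Data.docx in place (the question text is invented, and requests is an extra dependency here):

import requests

BASE = "http://localhost:8000"

# Build (or rebuild) the FAISS index from ./data/Data.docx.
print(requests.get(f"{BASE}/setup").json())  # {"response": "Vector Store DB is ready."}

# Ask a question; /chat wraps the cleaned answer in a JSON object.
r = requests.post(f"{BASE}/chat", json={"query_text": "What courses does TIET offer?"})
r.raise_for_status()
print(r.json()["response"])

Note that unlike the old read_item, the new chat route reports failures as 4xx/5xx responses via HTTPException rather than 200s carrying error strings, so raise_for_status() cleanly separates errors from real answers.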