# Page header captured together with this file (not code):
# author: 0504ankitsharma
# commit: 405e044 "Update app/main.py" (verified)
import functools
import os
import re
import time

import nltk
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import UnstructuredWordDocumentLoader as DocxLoader
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from openai import OpenAI
from pydantic import BaseModel
# Writable locations for caches and generated data, all rooted at /tmp.
cache_dir = '/tmp'
writable_dir = os.path.join(cache_dir, 'vectors_db')  # where the FAISS index is saved/loaded
nltk_data_path = os.path.join(cache_dir, 'nltk_data')

# Redirect the transformers / Hugging Face / generic XDG caches into cache_dir.
# os.environ.update assigns the keys in the order listed.
os.environ.update({
    'TRANSFORMERS_CACHE': os.path.join(cache_dir, 'transformers_cache'),
    'HF_HOME': os.path.join(cache_dir, 'huggingface'),
    'XDG_CACHE_HOME': cache_dir,
})

# Make NLTK search the writable data directory.
nltk.data.path.append(nltk_data_path)

# Create the data directories up front (no error if they already exist).
os.makedirs(nltk_data_path, exist_ok=True)
os.makedirs(writable_dir, exist_ok=True)

# Fetch the 'punkt' tokenizer resource into the writable NLTK directory.
nltk.download('punkt', download_dir=nltk_data_path)
def clean_response(response):
    r"""Tidy up an LLM answer before it is returned by the /chat endpoint.

    Processing order:

    1. Literal ``\n`` escape sequences (a backslash followed by ``n``) that the
       model sometimes emits are turned into real newlines rather than being
       deleted, so the words on either side are not glued together.
    2. Windows line endings (CRLF) are normalised to LF.
    3. Runs of consecutive newlines are collapsed into a single newline.
    4. Surrounding whitespace is stripped, then any enclosing quotation marks,
       then whitespace once more (so '" answer "' becomes 'answer'). Doing this
       after step 1 also removes quotes that were previously hidden behind a
       trailing escaped newline.

    Args:
        response: Raw answer text produced by the model.

    Returns:
        The cleaned answer text.
    """
    # Convert escaped newlines first so every later step sees real newlines.
    cleaned = response.replace('\\n', '\n').replace('\r\n', '\n')
    # Replace multiple newlines with a single newline.
    cleaned = re.sub(r'\n+', '\n', cleaned)
    # Remove leading/trailing whitespace, then any enclosing quotation marks.
    cleaned = cleaned.strip()
    cleaned = re.sub(r'^["\']+|["\']+$', '', cleaned)
    # Quotes may have wrapped padded text; trim what they enclosed.
    return cleaned.strip()
app = FastAPI()

# CORS policy: any origin, any method, any header, credentials allowed.
# NOTE(review): a "*" origin combined with allow_credentials=True effectively
# lets any website make credentialed cross-origin requests to this API —
# confirm that is intended.
_cors_options = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
    "allow_credentials": True,
}
app.add_middleware(CORSMiddleware, **_cors_options)
# OpenAI credentials come from the OPENAI_API_KEY environment variable
# (None when it is unset).
openai_api_key = os.getenv('OPENAI_API_KEY')

# Chat model shared by every /chat request.
llm = ChatOpenAI(
    model_name="gpt-4-turbo-preview",
    temperature=0.7,
    max_tokens=200,  # upper bound on the number of generated tokens per answer
    api_key=openai_api_key,
)
@app.get("/")
def read_root():
    """Root endpoint; always responds with the static ``{"Hello": "World"}`` payload."""
    payload = {"Hello": "World"}
    return payload
class Query(BaseModel):
    """Request body accepted by the /chat endpoint."""

    # The user's question; passed as the 'input' of the retrieval chain.
    query_text: str
# Prompt for the retrieval-augmented answer chain used by /chat.
# {context} receives the retrieved document chunks (filled in by
# create_stuff_documents_chain) and {input} receives the user's question (the
# 'input' key passed to retrieval_chain.invoke).
# The canned refusal message is enclosed in a matching pair of double quotes so
# the model can tell exactly where the reply it should reproduce ends.
prompt = ChatPromptTemplate.from_template(
    """
You are a helpful assistant designed specifically for the Thapar Institute of Engineering and Technology (TIET), a renowned technical college. Your task is to answer all queries related to TIET in a concise manner. Every response you provide should be relevant to the context of TIET. If a question falls outside of this context, please decline by stating, 'Sorry, I cannot help with that.' If you do not know the answer to a question, do not attempt to fabricate a response; instead, politely decline.
If the query is not related to TIET or falls outside the context of education, respond with:
"Sorry, I cannot help with that. I'm specifically designed to answer questions about the Thapar Institute of Engineering and Technology.
For more information, please contact our toll-free number: 18002024100 or E-mail us at admissions@thapar.edu"
<context>
{context}
</context>
Question: {input}
"""
)
def vector_embedding(file_path="./data/Data.docx"):
    """Build the FAISS vector store from a Word document and save it to ``writable_dir``.

    The document is loaded with ``DocxLoader`` and split into chunks of 500
    characters with 100 characters of overlap. The chunks are embedded with the
    model returned by ``get_embeddings`` and persisted with
    ``FAISS.save_local``, which is where /chat loads them from.

    Args:
        file_path: Path to the ``.docx`` source document. The default,
            ``./data/Data.docx``, is resolved against the current working
            directory.

    Returns:
        ``{"response": "Vector Store DB Is Ready"}`` on success. Otherwise a
        ``{"response": "Error: ..."}`` dict when the file is missing, contains
        no text, or any processing step raises.
    """
    if not os.path.exists(file_path):
        print(f"The file {file_path} does not exist.")
        return {"response": "Error: Data file not found"}

    try:
        documents = DocxLoader(file_path).load()
        print(f"Loaded document: {file_path}")

        text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
        chunks = text_splitter.split_documents(documents)
        print(f"Created {len(chunks)} chunks.")
        if not chunks:
            # FAISS.from_documents would fail on an empty list with an opaque
            # library error; report the real cause instead.
            return {"response": "Error: Data file contains no text"}

        # Reuse the shared embedding factory so indexing and querying always
        # use the same model configuration.
        db = FAISS.from_documents(chunks, get_embeddings())
        # Save the FAISS vector store to the writable directory.
        db.save_local(writable_dir)
        print(f"Vector store created and saved successfully to {writable_dir}.")
        return {"response": "Vector Store DB Is Ready"}
    except Exception as e:
        # Top-level boundary for the /setup endpoint: report the failure in the
        # response body instead of raising.
        print(f"An error occurred: {str(e)}")
        return {"response": f"Error: {str(e)}"}
@functools.lru_cache(maxsize=1)
def get_embeddings():
    """Return the shared BGE embedding model used for indexing and retrieval.

    The instance is created once and then cached. Without the cache every /chat
    request constructed a new ``HuggingFaceBgeEmbeddings``, which reloads the
    model, an expensive step repeated for each question.

    Returns:
        A ``HuggingFaceBgeEmbeddings`` for ``BAAI/bge-base-en`` configured to
        produce normalized embeddings.
    """
    model_name = "BAAI/bge-base-en"
    encode_kwargs = {'normalize_embeddings': True}
    return HuggingFaceBgeEmbeddings(model_name=model_name, encode_kwargs=encode_kwargs)
@app.post("/chat")
def read_item(query: Query):
    """Answer a user question using retrieval over the TIET document index.

    The FAISS index written by /setup is loaded from ``writable_dir``. The most
    relevant chunks are stuffed into ``prompt`` and sent to ``llm``, and the
    model's answer is post-processed with ``clean_response``.

    Args:
        query: Request body; ``query.query_text`` holds the user's question.

    Returns:
        ``{"response": <text>}`` containing either the cleaned answer or a
        message explaining why no answer was produced (empty query, or the
        vector store could not be loaded).
    """
    prompt1 = query.query_text
    # Reject empty / whitespace-only questions before doing any expensive work
    # (loading the embedding model and the index).
    if not prompt1.strip():
        return {"response": "No Query Found"}

    try:
        embeddings = get_embeddings()
        # load_local deserializes pickled data, hence the explicit opt-in. The
        # index is produced by this app's own /setup endpoint; never point this
        # at a directory whose contents come from outside.
        vectors = FAISS.load_local(writable_dir, embeddings, allow_dangerous_deserialization=True)
    except Exception as e:
        print(f"Error loading vector store: {str(e)}")
        return {"response": "Vector Store Not Found or Error Loading. Please run /setup first."}

    # perf_counter measures elapsed wall-clock time. process_time counts only
    # this process's CPU time and would miss the wait on the OpenAI API.
    start = time.perf_counter()
    document_chain = create_stuff_documents_chain(llm, prompt)
    retriever = vectors.as_retriever()
    retrieval_chain = create_retrieval_chain(retriever, document_chain)
    response = retrieval_chain.invoke({'input': prompt1})
    print("Response time:", time.perf_counter() - start)

    # Apply the cleaning function to the response.
    cleaned_response = clean_response(response['answer'])
    print("Cleaned response:", repr(cleaned_response))
    return {"response": cleaned_response}
@app.get("/setup")
def setup():
    """(Re)build and save the FAISS vector store from the data document.

    Delegates to ``vector_embedding`` and returns its ``{"response": ...}`` dict
    unchanged.

    NOTE(review): this GET request changes server-side state (it overwrites the
    saved index); consider POST if the clients can be updated.
    """
    outcome = vector_embedding()
    return outcome
if __name__ == "__main__":
    # uvicorn is needed only when this module is executed directly, so it is
    # imported here rather than at the top of the file.
    from uvicorn import run as serve_app

    serve_app(app, port=7860, host="0.0.0.0")