# Source: uploaded by 0504ankitsharma ("upload 9 files", commit 232f6b1, verified), 5.45 kB.
import functools
import os
import re
import time

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import UnstructuredWordDocumentLoader as DocxLoader
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from openai import OpenAI
from pydantic import BaseModel
def clean_response(response):
    """Normalize raw LLM answer text before returning it to the client.

    Steps, in order:
      1. Drop literal two-character ``\\n`` escape sequences (backslash + "n").
      2. Strip leading/trailing whitespace.
      3. Remove runs of quote characters (``"`` or ``'``) at either end.
         Each end is handled independently, so a lone leading quote is
         removed too.
      4. Strip again, so whitespace that sat just inside the quotes is removed.
      5. Collapse runs of real newlines into a single newline.

    Args:
        response: the answer string produced by the model.

    Returns:
        The cleaned string.
    """
    # Remove escaped "\n" first: doing it after collapsing newlines could leave
    # behind "\n\n" pairs that the collapse step has already passed over.
    cleaned = response.replace("\\n", "")
    cleaned = cleaned.strip()
    # Strip quote runs at the start or end of the text.
    cleaned = re.sub(r'^["\']+|["\']+$', '', cleaned)
    # Whitespace that was inside the quotes is now at the edges; strip it too.
    cleaned = cleaned.strip()
    # Replace multiple newlines with a single newline.
    cleaned = re.sub(r'\n+', '\n', cleaned)
    return cleaned
app = FastAPI()

# CORS policy: accept requests from any origin, using any method and any header.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive (any site may send credentialed requests) — confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# OpenAI key read from the environment; None when OPENAI_API_KEY is unset.
openai_api_key = os.getenv('OPENAI_API_KEY')

# Chat model shared by the /chat retrieval chain.
# "gpt-3.5-turbo" can be substituted as a cheaper option.
llm = ChatOpenAI(
    model_name="gpt-4-turbo-preview",
    temperature=0.7,
    api_key=openai_api_key,
)
@app.get("/")
def read_root():
    """Root endpoint: always returns the static payload ``{"Hello": "World"}``."""
    payload = {"Hello": "World"}
    return payload
class Query(BaseModel):
    """Request body for ``POST /chat``."""

    # The user's question, passed to the retrieval chain as ``input``.
    query_text: str
# Prompt for the retrieval chain built in read_item: {context} is filled with
# the retrieved document chunks and {input} with the user's question.
# The decline message is a properly closed quotation so the model does not
# echo a dangling quote mark.
prompt = ChatPromptTemplate.from_template(
    """
You are a helpful assistant designed specifically for the Thapar Institute of Engineering and Technology (TIET), a renowned technical college. Your task is to answer all queries related to TIET. Every response you provide should be relevant to the context of TIET. If a question falls outside of this context, please decline by stating, 'Sorry, I cannot help with that.' If you do not know the answer to a question, do not attempt to fabricate a response; instead, politely decline.
You may elaborate on your answers slightly to provide more information, but avoid sounding boastful or exaggerating. Stay focused on the context provided.
If the query is not related to TIET or falls outside the context of education, respond with:
"Sorry, I cannot help with that. I'm specifically designed to answer questions about the Thapar Institute of Engineering and Technology.
For more information, please contact us at our toll-free number: 18002024100 or E-mail us at admissions@thapar.edu"
<context>
{context}
</context>
Question: {input}
"""
)
def vector_embedding():
    """Build the FAISS vector store from ./data/Data.docx and save it to ./vectors_db.

    The document is split into overlapping chunks, embedded with the same
    model the query path uses (``get_embeddings()``), and persisted with
    ``FAISS.save_local``.

    Returns:
        dict: ``{"response": "Vector Store DB Is Ready"}`` on success, otherwise
        ``{"response": "Error: ..."}`` describing what went wrong. Errors are
        printed and returned; they are never raised.
    """
    file_path = "./data/Data.docx"
    try:
        if not os.path.exists(file_path):
            print(f"The file {file_path} does not exist.")
            return {"response": "Error: Data file not found"}

        documents = DocxLoader(file_path).load()
        print(f"Loaded document: {file_path}")

        # 500-character chunks; the 100-character overlap means text near a
        # chunk boundary appears in both neighbouring chunks.
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
        chunks = text_splitter.split_documents(documents)
        print(f"Created {len(chunks)} chunks.")
        if not chunks:
            # FAISS cannot build an index from zero vectors; report it clearly.
            print("No text chunks were produced from the data file.")
            return {"response": "Error: No content found in data file"}

        # Reuse the shared embedding factory so the index and the queries in
        # /chat are always embedded by the same model with the same settings.
        db = FAISS.from_documents(chunks, get_embeddings())
        db.save_local("./vectors_db")
        print("Vector store created and saved successfully.")
        return {"response": "Vector Store DB Is Ready"}
    except Exception as e:
        print(f"An error occurred: {str(e)}")
        return {"response": f"Error: {str(e)}"}
@functools.lru_cache(maxsize=1)
def get_embeddings():
    """Return the BGE embedding model used for both indexing and querying.

    Constructing HuggingFaceBgeEmbeddings loads the underlying model. The
    result is therefore cached, so every /chat request after the first reuses
    one instance instead of reloading the model.

    Returns:
        HuggingFaceBgeEmbeddings for "BAAI/bge-base-en" with
        ``normalize_embeddings=True``.
    """
    return HuggingFaceBgeEmbeddings(
        model_name="BAAI/bge-base-en",
        encode_kwargs={'normalize_embeddings': True},
    )
@app.post("/chat")
def read_item(query: Query):
    """Answer a question about TIET using retrieval-augmented generation.

    Loads the FAISS index written by ``/setup``, retrieves chunks relevant to
    ``query.query_text``, and has ``llm`` answer with ``prompt``.

    Args:
        query: request body whose ``query_text`` holds the user's question.

    Returns:
        One of:
          - the cleaned answer string;
          - ``"No Query Found"`` if the question is empty or whitespace-only;
          - ``{"response": ...}`` if the vector store cannot be loaded.
    """
    prompt1 = query.query_text
    # Reject blank questions before doing any index loading or LLM calls.
    if not (prompt1 and prompt1.strip()):
        return "No Query Found"

    try:
        embeddings = get_embeddings()
        # load_local deserializes a pickled docstore. This is acceptable only
        # because ./vectors_db is produced by this service's own /setup step;
        # never point it at files from an untrusted source.
        vectors = FAISS.load_local("./vectors_db", embeddings, allow_dangerous_deserialization=True)
    except Exception as e:
        print(f"Error loading vector store: {str(e)}")
        return {"response": "Vector Store Not Found or Error Loading. Please run /setup first."}

    # perf_counter measures wall-clock latency. process_time would only count
    # CPU time and miss the time spent waiting on the OpenAI API.
    start = time.perf_counter()
    document_chain = create_stuff_documents_chain(llm, prompt)
    retriever = vectors.as_retriever()
    retrieval_chain = create_retrieval_chain(retriever, document_chain)
    response = retrieval_chain.invoke({'input': prompt1})
    print("Response time:", time.perf_counter() - start)

    cleaned_response = clean_response(response['answer'])
    # Debug output of the exact string returned to the client.
    print("Cleaned response:", repr(cleaned_response))
    return cleaned_response
@app.get("/setup")
def setup():
    """Rebuild the FAISS vector store and report the outcome.

    Delegates entirely to ``vector_embedding()`` and returns its
    ``{"response": ...}`` dict unchanged.

    NOTE(review): this is an unauthenticated GET that triggers an expensive
    index rebuild; consider POST plus some form of access control.
    """
    outcome = vector_embedding()
    return outcome
# Debug aid: uncomment to confirm at startup whether the API key is present.
# print(f"API key set: {'Yes' if os.environ.get('OPENAI_API_KEY') else 'No'}")
if __name__ == "__main__":
    # Serve the app directly when run as a script, on all interfaces, port 8000.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")