from fastapi import FastAPI, File, UploadFile, HTTPException, Response, Request
from fastapi.middleware.cors import CORSMiddleware
import uuid
from datetime import datetime, timedelta
import asyncio
from typing import List, Dict, Any
from io import BytesIO, StringIO
from docx import Document
from langchain.docstore.document import Document as langchain_Document
from PyPDF2 import PdfReader
import csv
from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.memory import ConversationBufferMemory
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from dotenv import load_dotenv

load_dotenv()
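
# FastAPI backend for a document Q&A chat: users upload .txt/.csv/.docx/.pdf
# files, a per-user ConversationalRetrievalChain is built over their contents,
# and subsequent questions are answered against the retrieved chunks.
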
class Document_Processor:
    """Converts uploaded file bytes into LangChain Documents by file type."""

    def __init__(self, file_details: List[Dict[str, Any]]):
        self.file_details = file_details

    def get_docs(self) -> List[langchain_Document]:
        docs = []
        for file_detail in self.file_details:
            if file_detail["name"].endswith(".txt"):
                docs.extend(self.get_txt_docs(file_detail=file_detail))
            elif file_detail["name"].endswith(".csv"):
                docs.extend(self.get_csv_docs(file_detail=file_detail))
            elif file_detail["name"].endswith(".docx"):
                docs.extend(self.get_docx_docs(file_detail=file_detail))
            elif file_detail["name"].endswith(".pdf"):
                docs.extend(self.get_pdf_docs(file_detail=file_detail))
        return docs

    def get_txt_docs(self, file_detail: Dict[str, Any]) -> List[langchain_Document]:
        text = file_detail["content"].decode("utf-8")
        source = file_detail["name"]
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000, chunk_overlap=100
        )
        text_docs = text_splitter.create_documents(
            [text], metadatas=[{"source": source}]
        )
        return text_docs

    def get_csv_docs(self, file_detail: Dict[str, Any]) -> List[langchain_Document]:
        csv_data = file_detail["content"]
        source = file_detail["name"]
        csv_string = csv_data.decode("utf-8")
        # Use StringIO to create a file-like object from the string
        csv_file = StringIO(csv_string)
        csv_reader = csv.DictReader(csv_file)
        csv_docs = []
        for row in csv_reader:
            # Render each row's key/value pairs as one document
            page_content = ""
            for key, value in row.items():
                page_content += f"{key}: {value}\n"
            doc = langchain_Document(
                page_content=page_content, metadata={"source": source}
            )
            csv_docs.append(doc)
        return csv_docs

    def get_pdf_docs(self, file_detail: Dict[str, Any]) -> List[langchain_Document]:
        pdf_content = BytesIO(file_detail["content"])
        source = file_detail["name"]
        reader = PdfReader(pdf_content)
        pdf_text = ""
        for page in reader.pages:
            # extract_text() can return None for image-only pages
            pdf_text += (page.extract_text() or "") + "\n"
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000, chunk_overlap=100
        )
        pdf_docs = text_splitter.create_documents(
            texts=[pdf_text], metadatas=[{"source": source}]
        )
        return pdf_docs

    def get_docx_docs(self, file_detail: Dict[str, Any]) -> List[langchain_Document]:
        docx_content = BytesIO(file_detail["content"])
        source = file_detail["name"]
        document = Document(docx_content)
        docx_text = " ".join([paragraph.text for paragraph in document.paragraphs])
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000, chunk_overlap=100
        )
        docx_docs = text_splitter.create_documents(
            [docx_text], metadatas=[{"source": source}]
        )
        return docx_docs

class Conversational_Chain:
    """Builds a ConversationalRetrievalChain over the uploaded documents."""

    def __init__(self, file_details: List[Dict[str, Any]]):
        self.llm_model = ChatOpenAI()
        self.embeddings = OpenAIEmbeddings()
        self.file_details = file_details

    def create_conversational_chain(self):
        docs = Document_Processor(self.file_details).get_docs()
        memory = ConversationBufferMemory(
            memory_key="chat_history", return_messages=True
        )
        vectordb = Chroma.from_documents(
            docs,
            self.embeddings,
        )
        retriever = vectordb.as_retriever()
        conversation_chain = ConversationalRetrievalChain.from_llm(
            llm=self.llm_model,
            retriever=retriever,
            condense_question_prompt=self.get_question_generator_prompt(),
            combine_docs_chain_kwargs={
                "document_prompt": self.get_document_prompt(),
                "prompt": self.get_final_prompt(),
            },
            memory=memory,
        )
        return conversation_chain

    def get_document_prompt(self) -> PromptTemplate:
        document_template = """Document Content: {page_content}
Document Path: {source}"""
        return PromptTemplate(
            input_variables=["page_content", "source"],
            template=document_template,
        )

    def get_question_generator_prompt(self) -> PromptTemplate:
        question_generator_template = """Combine the chat history and follow-up question into
        a standalone question.\n Chat History: {chat_history}\n
        Follow-up question: {question}
        """
        return PromptTemplate.from_template(question_generator_template)

    def get_final_prompt(self) -> ChatPromptTemplate:
        final_prompt_template = """Answer the question based on the context and chat_history.
        If you cannot find an answer, ask the user related follow-up questions.
        Use only the basename of the file path as the name of the documents.
        Mention the names of the documents you used in your answer.
        context:
        {context}
        chat_history:
        {chat_history}
        question:
        {question}
        Answer:
        """
        messages = [
            SystemMessagePromptTemplate.from_template(final_prompt_template),
            HumanMessagePromptTemplate.from_template("{question}"),
        ]
        return ChatPromptTemplate.from_messages(messages)

class UserSessionManager:
    """Tracks each user's chain and last-request time; sessions live in memory."""

    def __init__(self):
        self.sessions = {}
        self.last_request_time = {}

    def get_session(self, user_id: str):
        if user_id not in self.sessions:
            self.sessions[user_id] = None
            self.last_request_time[user_id] = datetime.now()
        return self.sessions[user_id]

    def set_session(self, user_id: str, conversational_chain):
        self.sessions[user_id] = conversational_chain
        self.last_request_time[user_id] = datetime.now()

    def delete_inactive_sessions(self, inactive_period: timedelta):
        current_time = datetime.now()
        for user_id, last_request_time in list(self.last_request_time.items()):
            if current_time - last_request_time > inactive_period:
                del self.sessions[user_id]
                del self.last_request_time[user_id]

app = FastAPI()

origins = ["https://viboognesh-react-chat.static.hf.space"]
# origins = ["http://localhost:3000"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)

user_session_manager = UserSessionManager()

@app.middleware("http")
async def update_last_request_time(request: Request, call_next):
    # Refresh the user's last-request timestamp on every request
    user_id = request.cookies.get("user_id")
    if user_id:
        user_session_manager.last_request_time[user_id] = datetime.now()
    response = await call_next(request)
    return response

async def check_inactivity():
    # Every 10 minutes, drop sessions idle for more than 2 hours
    inactive_period = timedelta(hours=2)
    while True:
        await asyncio.sleep(600)
        user_session_manager.delete_inactive_sessions(inactive_period)

@app.on_event("startup")
async def startup_event():
    asyncio.create_task(check_inactivity())

# Route path "/uploadfiles" is an assumption; adjust it to match the frontend.
@app.post("/uploadfiles")
async def upload_files(
    request: Request, response: Response, files: List[UploadFile] = File(...)
):
    file_details = []
    try:
        for file in files:
            content = await file.read()
            name = f"{file.filename}"
            details = {"content": content, "name": name}
            file_details.append(details)
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
    # Cookies are read from the Request; the Response only sets them
    user_id = request.cookies.get("user_id")
    if not user_id:
        user_id = str(uuid.uuid4())
        response.set_cookie(key="user_id", value=user_id)
    try:
        conversational_chain = Conversational_Chain(
            file_details
        ).create_conversational_chain()
        user_session_manager.set_session(
            user_id=user_id, conversational_chain=conversational_chain
        )
        print("conversational_chain_manager created")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return {"message": "ConversationalRetrievalChain is created. Please ask questions."}

# Route path "/predict" and the query-string parameter style are assumptions.
@app.post("/predict")
async def predict(query: str, request: Request, response: Response):
    user_id = request.cookies.get("user_id")
    if not user_id:
        user_id = str(uuid.uuid4())
        response.set_cookie(key="user_id", value=user_id)
    try:
        conversational_chain = user_session_manager.get_session(user_id=user_id)
        if conversational_chain is None:
            # No files uploaded yet: answer directly and ask the user to upload
            system_prompt = "Answer the question and also ask the user to upload files to ask questions from the files.\n"
            llm_model = ChatOpenAI()
            llm_response = llm_model.invoke(system_prompt + query)
            answer = llm_response.content
        else:
            chain_response = conversational_chain.invoke({"question": query})
            answer = chain_response["answer"]
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    print("predict called")
    return {"answer": answer}