from fastapi import FastAPI, HTTPException from pydantic import BaseModel from contextlib import asynccontextmanager from langchain_community.document_loaders import PyPDFLoader from langchain_community.document_loaders import WebBaseLoader from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_community.vectorstores import FAISS from langchain_openai import OpenAIEmbeddings from langchain_community.embeddings import HuggingFaceEmbeddings from langchain_openai import ChatOpenAI from langchain_groq import ChatGroq from langchain.chains import create_history_aware_retriever, create_retrieval_chain from langchain.chains.combine_documents import create_stuff_documents_chain from langchain_community.chat_message_histories import ChatMessageHistory from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.runnables.history import RunnableWithMessageHistory from transformers import pipeline from bs4 import BeautifulSoup from dotenv import load_dotenv from PIL import Image import base64 import requests import docx2txt import pptx import os import utils ## APPLICATION LIFESPAN # Load the environment variables using FastAPI lifespan event so that they are available throughout the application @asynccontextmanager async def lifespan(app: FastAPI): # Load the environment variables load_dotenv() #os.environ['OPENAI_API_KEY'] = os.getenv("OPENAI_API_KEY") ## Langsmith tracking os.environ["LANGCHAIN_TRACING_V2"] = "true" # Enable tracing to capture all the monitoring results os.environ["LANGCHAIN_API_KEY"] = os.getenv("LANGCHAIN_API_KEY") ## load the Groq API key os.environ['GROQ_API_KEY'] = os.getenv("GROQ_API_KEY") os.environ['HF_TOKEN'] = os.getenv("HF_TOKEN") os.environ['NGROK_AUTHTOKEN'] = os.getenv("NGROK_AUTHTOKEN") global image_to_text image_to_text = pipeline("image-to-text", model="Salesforce/blip-image-captioning-large") yield # Delete all the temporary images utils.unlink_images("/images") ## FASTAPI APP # Initialize the FastAPI app app = FastAPI(lifespan=lifespan, docs_url="/") ## PYDANTIC MODELS # Define an APIKey Pydantic model for the request body class APIKey(BaseModel): api_key: str # Define a FileInfo Pydantic model for the request body class FileInfo(BaseModel): file_path: str file_type: str # Define an Image Pydantic model for the request body class Image(BaseModel): image_path: str # Define a Website Pydantic model for the request body class Website(BaseModel): website_link: str # Define a Question Pydantic model for the request body class Question(BaseModel): question: str resource: str ## FUNCTIONS # Function to combine all documents def format_docs(docs): return "\n\n".join(doc.page_content for doc in docs) # Function to encode the image def encode_image(image_path): with open(image_path, "rb") as image_file: return base64.b64encode(image_file.read()).decode('utf-8') ## FASTAPI ENDPOINTS ## GET - / @app.get("/") async def welcome(): return "Welcome to Brainbot!" ## POST - /set_api_key @app.post("/set_api_key") async def set_api_key(api_key: APIKey): os.environ["OPENAI_API_KEY"] = api_key.api_key return "API key set successfully!" ## POST - /load_file # Load the file, split it into document chunks, and upload the document embeddings into a vectorstore @app.post("/load_file/{llm}") async def load_file(llm: str, file_info: FileInfo): file_path = file_info.file_path file_type = file_info.file_type # Read the file and split it into document chunks try: # Initialize the text splitter text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200) # Check the file type and load each file according to its type if file_type == "application/pdf": # Read pdf file loader = PyPDFLoader(file_path) docs = loader.load() elif file_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document": # Read docx file text = docx2txt.process(file_path) docs = text_splitter.create_documents([text]) elif file_type == "text/plain": # Read txt file with open(file_path, 'r') as file: text = file.read() docs = text_splitter.create_documents([text]) elif file_type == "application/vnd.openxmlformats-officedocument.presentationml.presentation": # Read pptx file presentation = pptx.Presentation(file_path) # Initialize an empty list to store slide texts slide_texts = [] # Iterate through slides and extract text for slide in presentation.slides: # Initialize an empty string to store text for each slide slide_text = "" # Iterate through shapes in the slide for shape in slide.shapes: if hasattr(shape, "text"): slide_text += shape.text + "\n" # Add shape text to slide text # Append slide text to the list slide_texts.append(slide_text.strip()) docs = text_splitter.create_documents(slide_texts) elif file_type == "text/html": # Read html file with open(file_path, 'r') as file: soup = BeautifulSoup(file, 'html.parser') text = soup.get_text() docs = text_splitter.create_documents([text]) # Delete the temporary file os.unlink(file_path) # Split the document into chunks documents = text_splitter.split_documents(docs) if llm == "GPT-4": embeddings = OpenAIEmbeddings() elif llm == "GROQ": embeddings = HuggingFaceEmbeddings() # Save document embeddings into the FAISS vectorstore global file_vectorstore file_vectorstore = FAISS.from_documents(documents, embeddings) except Exception as e: # Handle errors raise HTTPException(status_code=500, detail=str(e.with_traceback)) return "File uploaded successfully!" ## POST - /image # Interpret the image using the LLM - OpenAI Vision @app.post("/image/{llm}") async def interpret_image(llm: str, image: Image): try: # Get the base64 string base64_image = encode_image(image.image_path) if llm == "GPT-4": headers = { "Content-Type": "application/json", "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}" } payload = { "model": "gpt-4-turbo", "messages": [ { "role": "user", "content": [ { "type": "text", "text": "What's in this image?" }, { "type": "image_url", "image_url": { "url": f"data:image/jpeg;base64,{base64_image}" } } ] } ], "max_tokens": 300 } response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload) response = response.json() # Extract description about the image description = response["choices"][0]["message"]["content"] elif llm == "GROQ": # Use image-to-text model from Hugging Face response = image_to_text(image.image_path) # Extract description about the image description = response[0]["generated_text"] chat = ChatGroq(temperature=0, groq_api_key=os.environ["GROQ_API_KEY"], model_name="Llama3-8b-8192") system = "You are an assistant to understand and interpret images." human = "{text}" prompt = ChatPromptTemplate.from_messages([("system", system), ("human", human)]) chain = prompt | chat text = f"Explain the following image description in a small paragraph. {description}" response = chain.invoke({"text": text}) description = str.capitalize(description) + ". " + response.content except Exception as e: # Handle errors raise HTTPException(status_code=500, detail=str(e)) return description ## POST - load_link # Load the website content through scraping, split it into document chunks, and upload the document # embeddings into a vectorstore @app.post("/load_link/{llm}") async def website_info(llm: str, link: Website): try: # load, chunk, and index the content of the html page loader = WebBaseLoader(web_paths=(link.website_link,),) global web_documents web_documents = loader.load() # split the document into chunks text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200) documents = text_splitter.split_documents(web_documents) if llm == "GPT-4": embeddings = OpenAIEmbeddings() elif llm == "GROQ": embeddings = HuggingFaceEmbeddings() # Save document embeddings into the FAISS vectorstore global website_vectorstore website_vectorstore = FAISS.from_documents(documents, embeddings) except Exception as e: # Handle errors raise HTTPException(status_code=500, detail=str(e)) return "Website loaded successfully!" ## POST - /answer_with_chat_history # Retrieve the answer to the question using LLM and the RAG chain maintaining the chat history @app.post("/answer_with_chat_history/{llm}") async def get_answer_with_chat_history(llm: str, question: Question): user_question = question.question resource = question.resource selected_llm = llm try: # Initialize the LLM if selected_llm == "GPT-4": llm = ChatOpenAI(model="gpt-4-turbo", temperature=0) elif selected_llm == "GROQ": llm = ChatGroq(groq_api_key=os.environ["GROQ_API_KEY"], model_name="Llama3-8b-8192") # extract relevant context from the document using the retriever with similarity search if resource == "file": retriever = file_vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 5}) elif resource == "web": retriever = website_vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 5}) ### Contextualize question ### contextualize_q_system_prompt = """Given a chat history and the latest user question \ which might reference context in the chat history, formulate a standalone question \ which can be understood without the chat history. Do NOT answer the question, \ just reformulate it if needed and otherwise return it as is.""" contextualize_q_prompt = ChatPromptTemplate.from_messages( [ ("system", contextualize_q_system_prompt), MessagesPlaceholder("chat_history"), ("human", "{input}"), ] ) history_aware_retriever = create_history_aware_retriever( llm, retriever, contextualize_q_prompt ) ### Answer question ### qa_system_prompt = """You are an assistant for question-answering tasks. \ Use the following pieces of retrieved context to answer the question. \ If you don't know the answer, just say that you don't know. \ Use three sentences maximum and keep the answer concise.\ {context}""" qa_prompt = ChatPromptTemplate.from_messages( [ ("system", qa_system_prompt), MessagesPlaceholder("chat_history"), ("human", "{input}"), ] ) question_answer_chain = create_stuff_documents_chain(llm, qa_prompt) rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain) ### Statefully manage chat history ### store = {} def get_session_history(session_id: str) -> BaseChatMessageHistory: if session_id not in store: store[session_id] = ChatMessageHistory() return store[session_id] conversational_rag_chain = RunnableWithMessageHistory( rag_chain, get_session_history, input_messages_key="input", history_messages_key="chat_history", output_messages_key="answer", ) response = conversational_rag_chain.invoke( {"input": user_question}, config={ "configurable": {"session_id": "abc123"} }, # constructs a key "abc123" in `store`. )["answer"] except Exception as e: # Handle errors raise HTTPException(status_code=500, detail=str(e)) return response