from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from contextlib import asynccontextmanager
from langchain_community.document_loaders import PyPDFLoader, WebBaseLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_groq import ChatGroq
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory
from transformers import pipeline
from bs4 import BeautifulSoup
from dotenv import load_dotenv
import base64
import requests
import docx2txt
import pptx
import os
import utils
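
# NOTE: shared state lives in module-level globals (the BLIP captioning
# pipeline and the FAISS vector stores set below via `global`), so every
# client of this API sees the same loaded file, website, and API key.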


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load env vars and the BLIP captioning model at startup; clean up
    leftover images at shutdown."""
    load_dotenv()

    os.environ["LANGCHAIN_TRACING_V2"] = "true"
    # os.getenv returns None for a missing key, which os.environ rejects with
    # a TypeError; default to "" so startup does not crash on a missing key.
    os.environ["LANGCHAIN_API_KEY"] = os.getenv("LANGCHAIN_API_KEY", "")
    os.environ["GROQ_API_KEY"] = os.getenv("GROQ_API_KEY", "")

    # Load the image-captioning pipeline once and reuse it across requests.
    global image_to_text
    image_to_text = pipeline("image-to-text", model="Salesforce/blip-image-captioning-large")
    yield

    utils.unlink_images("/images")


app = FastAPI(lifespan=lifespan)
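
# Per-session chat histories for /answer_with_chat_history. Kept at module
# level so a session's history persists across requests; defining this inside
# the endpoint (as before) discarded the history on every call.
store = {}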


class APIKey(BaseModel):
    api_key: str


class FileInfo(BaseModel):
    file_path: str
    file_type: str


class Image(BaseModel):
    image_path: str


class Website(BaseModel):
    website_link: str


class Question(BaseModel):
    question: str
    resource: str


def format_docs(docs):
    """Join retrieved documents into a single context string."""
    return "\n\n".join(doc.page_content for doc in docs)


def encode_image(image_path):
    """Return the base64-encoded contents of the image at image_path."""
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')


@app.get("/")
async def welcome():
    return "Welcome to Brainbot!"


@app.post("/set_api_key")
async def set_api_key(api_key: APIKey):
    # The key is stored process-wide, so all clients share the last key set.
    os.environ["OPENAI_API_KEY"] = api_key.api_key
    return "API key set successfully!"


@app.post("/load_file/{llm}")
async def load_file(llm: str, file_info: FileInfo):
    """Extract text from the given file, chunk it, and index the chunks in a
    FAISS vector store for retrieval."""
    file_path = file_info.file_path
    file_type = file_info.file_type

    try:
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)

        if file_type == "application/pdf":
            loader = PyPDFLoader(file_path)
            docs = loader.load()
        elif file_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
            text = docx2txt.process(file_path)
            docs = text_splitter.create_documents([text])
        elif file_type == "text/plain":
            with open(file_path, 'r') as file:
                text = file.read()
            docs = text_splitter.create_documents([text])
        elif file_type == "application/vnd.openxmlformats-officedocument.presentationml.presentation":
            presentation = pptx.Presentation(file_path)
            slide_texts = []

            for slide in presentation.slides:
                slide_text = ""

                # Gather the text from every shape on the slide.
                for shape in slide.shapes:
                    if hasattr(shape, "text"):
                        slide_text += shape.text + "\n"

                slide_texts.append(slide_text.strip())

            docs = text_splitter.create_documents(slide_texts)
        elif file_type == "text/html":
            with open(file_path, 'r') as file:
                soup = BeautifulSoup(file, 'html.parser')
                text = soup.get_text()
            docs = text_splitter.create_documents([text])
        else:
            # Without this branch, `docs` would be unbound for unknown types.
            raise HTTPException(status_code=400, detail=f"Unsupported file type: {file_type}")

        # The upload is no longer needed once its text has been extracted.
        os.unlink(file_path)

        documents = text_splitter.split_documents(docs)

        if llm == "GPT-4":
            embeddings = OpenAIEmbeddings()
        elif llm == "GROQ":
            embeddings = HuggingFaceEmbeddings()

        global file_vectorstore
        file_vectorstore = FAISS.from_documents(documents, embeddings)
    except HTTPException:
        raise
    except Exception as e:
        # str(e.with_traceback) stringified a bound method, not the error;
        # report the exception message itself instead.
        raise HTTPException(status_code=500, detail=str(e))
    return "File uploaded successfully!"


@app.post("/image/{llm}")
async def interpret_image(llm: str, image: Image):
    """Describe an image, either with GPT-4 vision or with a local BLIP
    caption expanded by a Groq-hosted Llama 3 model."""
    try:
        base64_image = encode_image(image.image_path)

        if llm == "GPT-4":
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"
            }

            payload = {
                "model": "gpt-4-turbo",
                "messages": [
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": "What's in this image?"
                            },
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:image/jpeg;base64,{base64_image}"
                                }
                            }
                        ]
                    }
                ],
                "max_tokens": 300
            }

            response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
            response = response.json()

            description = response["choices"][0]["message"]["content"]
        elif llm == "GROQ":
            # BLIP produces a short caption; Llama 3 expands it into a paragraph.
            response = image_to_text(image.image_path)
            description = response[0]["generated_text"]

            chat = ChatGroq(temperature=0, groq_api_key=os.environ["GROQ_API_KEY"], model_name="Llama3-8b-8192")
            system = "You are an assistant to understand and interpret images."
            human = "{text}"
            prompt = ChatPromptTemplate.from_messages([("system", system), ("human", human)])

            chain = prompt | chat
            text = f"Explain the following image description in a small paragraph. {description}"
            response = chain.invoke({"text": text})
            description = description.capitalize() + ". " + response.content
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    return description
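
# Example call to /image (hypothetical image path on the server):
#   curl -X POST http://localhost:8000/image/GROQ \
#        -H "Content-Type: application/json" \
#        -d '{"image_path": "images/photo.jpg"}'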


@app.post("/load_link/{llm}")
async def website_info(llm: str, link: Website):
    """Fetch a web page, chunk its text, and index it in a FAISS vector store."""
    try:
        loader = WebBaseLoader(web_paths=(link.website_link,))

        global web_documents
        web_documents = loader.load()

        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        documents = text_splitter.split_documents(web_documents)

        if llm == "GPT-4":
            embeddings = OpenAIEmbeddings()
        elif llm == "GROQ":
            embeddings = HuggingFaceEmbeddings()

        global website_vectorstore
        website_vectorstore = FAISS.from_documents(documents, embeddings)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    return "Website loaded successfully!"


@app.post("/answer_with_chat_history/{llm}")
async def get_answer_with_chat_history(llm: str, question: Question):
    """Answer a question against the indexed file or website, using chat
    history so follow-up questions keep their context."""
    user_question = question.question
    resource = question.resource
    selected_llm = llm

    try:
        if selected_llm == "GPT-4":
            llm = ChatOpenAI(model="gpt-4-turbo", temperature=0)
        elif selected_llm == "GROQ":
            llm = ChatGroq(groq_api_key=os.environ["GROQ_API_KEY"], model_name="Llama3-8b-8192")

        # Retrieve from whichever vector store matches the requested resource.
        if resource == "file":
            retriever = file_vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 5})
        elif resource == "web":
            retriever = website_vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 5})

        contextualize_q_system_prompt = """Given a chat history and the latest user question \
which might reference context in the chat history, formulate a standalone question \
which can be understood without the chat history. Do NOT answer the question, \
just reformulate it if needed and otherwise return it as is."""
        contextualize_q_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", contextualize_q_system_prompt),
                MessagesPlaceholder("chat_history"),
                ("human", "{input}"),
            ]
        )
        history_aware_retriever = create_history_aware_retriever(
            llm, retriever, contextualize_q_prompt
        )

        qa_system_prompt = """You are an assistant for question-answering tasks. \
Use the following pieces of retrieved context to answer the question. \
If you don't know the answer, just say that you don't know. \
Use three sentences maximum and keep the answer concise.\

{context}"""
        qa_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", qa_system_prompt),
                MessagesPlaceholder("chat_history"),
                ("human", "{input}"),
            ]
        )

        question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)

        rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)

        # Look up (or create) this session's history in the module-level store.
        def get_session_history(session_id: str) -> BaseChatMessageHistory:
            if session_id not in store:
                store[session_id] = ChatMessageHistory()
            return store[session_id]

        conversational_rag_chain = RunnableWithMessageHistory(
            rag_chain,
            get_session_history,
            input_messages_key="input",
            history_messages_key="chat_history",
            output_messages_key="answer",
        )

        # The session id is hardcoded, so all clients currently share one history.
        response = conversational_rag_chain.invoke(
            {"input": user_question},
            config={
                "configurable": {"session_id": "abc123"}
            },
        )["answer"]
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    return response
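
# Example call to /answer_with_chat_history ("resource" selects the indexed
# source: "file" or "web"):
#   curl -X POST http://localhost:8000/answer_with_chat_history/GROQ \
#        -H "Content-Type: application/json" \
#        -d '{"question": "What is this document about?", "resource": "file"}'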