# AIE4Midterm / app.py
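"""Streamlit Q&A app over two AI-governance PDFs (the Blueprint for an AI Bill of
Rights and NIST AI 600-1): fetch the PDFs, split them into chunks, embed the chunks
into a Chroma vector store, and answer questions with ChatOpenAI via a simple
retrieval-augmented generation pipeline."""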
import streamlit as st
import asyncio
import os
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate
from langchain.vectorstores import Chroma
from langchain.embeddings import OpenAIEmbeddings
from langchain.chat_models import ChatOpenAI
from PyPDF2 import PdfReader
import aiohttp
from io import BytesIO
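# Note: these `langchain.*` import paths match the legacy (pre-0.1) LangChain
# packaging this app appears to pin; newer releases move the Chroma, embeddings,
# and chat-model integrations into langchain_community / langchain_openai.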

# Set up API key
os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]
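# st.secrets is typically backed by a .streamlit/secrets.toml file, e.g. a line like
#   OPENAI_API_KEY = "sk-..."
# (illustrative value; supply a real key through the deployment's secrets store)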

# Set up prompts
system_template = "Use the following context to answer a user's question. If you cannot find the answer in the context, say you don't know the answer."
system_role_prompt = SystemMessagePromptTemplate.from_template(system_template)

user_prompt_template = "Context:\n{context}\n\nQuestion:\n{question}"
user_role_prompt = HumanMessagePromptTemplate.from_template(user_prompt_template)
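# For illustration, user_role_prompt.format(context=..., question=...) returns a
# HumanMessage whose content renders as:
#   Context:
#   <retrieved chunks>
#
#   Question:
#   <user question>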

# Define RetrievalAugmentedQAPipeline class
class RetrievalAugmentedQAPipeline:
    def __init__(self, llm: ChatOpenAI, vector_db: Chroma) -> None:
        self.llm = llm
        self.vector_db = vector_db

    async def arun_pipeline(self, user_query: str):
        # Retrieve the two most similar chunks and join them into one context block
        context_docs = self.vector_db.similarity_search(user_query, k=2)
        context_list = [doc.page_content for doc in context_docs]
        context_prompt = "\n".join(context_list)

        # Truncate the context so the final prompt stays a manageable size
        max_context_length = 12000
        if len(context_prompt) > max_context_length:
            context_prompt = context_prompt[:max_context_length]

        formatted_system_prompt = system_role_prompt.format()
        formatted_user_prompt = user_role_prompt.format(question=user_query, context=context_prompt)

        # agenerate expects a list of message *lists* (one inner list per generation)
        response = await self.llm.agenerate([[formatted_system_prompt, formatted_user_prompt]])
        return {"response": response.generations[0][0].text, "context": context_list}

# PDF processing functions
async def fetch_pdf(session, url):
    async with session.get(url) as response:
        if response.status == 200:
            return await response.read()
        st.error(f"Failed to fetch PDF from {url}")
        return None

async def process_pdf(pdf_content):
    pdf_reader = PdfReader(BytesIO(pdf_content))
    # extract_text() can return None for image-only pages, so fall back to ""
    text = "\n".join(page.extract_text() or "" for page in pdf_reader.pages)
    # Split into ~1000-character chunks with 200 characters of overlap
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    return text_splitter.split_text(text)

@st.cache_resource
def initialize_pipeline():
    # Build the pipeline once and cache it across Streamlit reruns
    return asyncio.run(main())

# Main execution
async def main():
    pdf_urls = [
        "https://www.whitehouse.gov/wp-content/uploads/2022/10/Blueprint-for-an-AI-Bill-of-Rights.pdf",
        "https://nvlpubs.nist.gov/nistpubs/ai/NIST.AI.600-1.pdf",
    ]
    all_chunks = []
    async with aiohttp.ClientSession() as session:
        # Download both PDFs concurrently, then chunk whichever ones succeeded
        pdf_contents = await asyncio.gather(*[fetch_pdf(session, url) for url in pdf_urls])
        for pdf_content in pdf_contents:
            if pdf_content:
                chunks = await process_pdf(pdf_content)
                all_chunks.extend(chunks)
    st.write(f"Created {len(all_chunks)} chunks from {len(pdf_urls)} PDF files")
    embeddings = OpenAIEmbeddings()
    vector_db = Chroma.from_texts(all_chunks, embeddings)
    chat_openai = ChatOpenAI()
    return RetrievalAugmentedQAPipeline(vector_db=vector_db, llm=chat_openai)

# Streamlit UI (Streamlit executes this script top to bottom, so no explicit
# entry point or st.run() call is needed)
st.title("AI Bill of Rights Q&A")
pipeline = initialize_pipeline()
user_query = st.text_input("Enter your question about the AI Bill of Rights:")
if user_query:
    result = asyncio.run(pipeline.arun_pipeline(user_query))
    st.write("Response:")
    st.write(result["response"])
    st.write("Context used:")
    for i, context in enumerate(result["context"], 1):
        st.write(f"{i}. {context[:100]}...")
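# To try the app locally (assuming the dependencies imported above are installed):
#   streamlit run app.py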