srinidhidevaraj's picture
Update app.py
b8901d3 verified
raw
history blame contribute delete
No virus
5.05 kB
import streamlit as st
import os
from langchain_groq import ChatGroq
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.embeddings import OllamaEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains import create_retrieval_chain
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.document_loaders import PyPDFDirectoryLoader
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
# from langchain.vectorstores.cassandra import Cassandra
from langchain_community.vectorstores import Cassandra
from langchain_community.llms import Ollama
from cassandra.auth import PlainTextAuthProvider
import tempfile
import cassio
from PyPDF2 import PdfReader
from cassandra.cluster import Cluster
import warnings
warnings.filterwarnings("ignore")
from dotenv import load_dotenv
import time
load_dotenv()
# ---------------------------------------------------------------------------
# Module configuration.  Everything below runs at import time, after
# load_dotenv() has been called above.
# ---------------------------------------------------------------------------

# Path of the Astra DB secure-connect bundle (resolved against the current
# working directory).
ASTRA_DB_SECURE_BUNDLE_PATH = "secure-connect-pdf-query-db.zip"

# Set the tracing flag in the process environment before anything reads it.
os.environ["LANGCHAIN_TRACING_V2"] = "true"

# LangChain settings; each is None when the variable is not set.
LANGCHAIN_API_KEY = os.environ.get("LANGCHAIN_API_KEY")
LANGCHAIN_PROJECT = os.environ.get("LANGCHAIN_PROJECT")
LANGCHAIN_ENDPOINT = os.environ.get("LANGCHAIN_ENDPOINT")

# Astra DB settings; each is None when the variable is not set.
ASTRA_DB_APPLICATION_TOKEN = os.environ.get("ASTRA_DB_APPLICATION_TOKEN")
ASTRA_DB_ID = os.environ.get("ASTRA_DB_ID")
ASTRA_DB_KEYSPACE = os.environ.get("ASTRA_DB_KEYSPACE")
ASTRA_DB_API_ENDPOINT = os.environ.get("ASTRA_DB_API_ENDPOINT")
ASTRA_DB_CLIENT_ID = os.environ.get("ASTRA_DB_CLIENT_ID")
ASTRA_DB_CLIENT_SECRET = os.environ.get("ASTRA_DB_CLIENT_SECRET")
ASTRA_DB_TABLE = os.environ.get("ASTRA_DB_TABLE")

# Groq API key (note the lower-case variable name).
groq_api_key = os.environ.get("groq_api_key")

# Initialise cassio with the token, database id and secure-connect bundle.
cassio.init(
    token=ASTRA_DB_APPLICATION_TOKEN,
    database_id=ASTRA_DB_ID,
    secure_connect_bundle=ASTRA_DB_SECURE_BUNDLE_PATH,
)

# Cloud connection settings keyed the same way doc_loader passes them to Cluster.
cloud_config = {"secure_connect_bundle": ASTRA_DB_SECURE_BUNDLE_PATH}
def doc_loader(pdf_reader):
    """Index a PDF into the Astra DB vector table and return the vector store.

    The PDF is loaded, split into overlapping chunks (1000 characters with a
    200-character overlap), embedded with ``BAAI/bge-small-en-v1.5`` on CPU,
    and written to the vector table after the rows already in that table have
    been removed.

    Args:
        pdf_reader: Filesystem path of the PDF. Despite the name this is not a
            reader object; it is handed straight to ``PyPDFLoader``.

    Returns:
        The ``Cassandra`` vector store that now holds the chunks of this PDF.
    """
    embeddings = HuggingFaceBgeEmbeddings(
        model_name='BAAI/bge-small-en-v1.5',
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': True},
    )

    # Load the raw documents and split them exactly once.  The previous
    # load_and_split() call applied a splitter of its own before the explicit
    # splitter below ran, so the text was split twice.
    documents = PyPDFLoader(pdf_reader).load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    final_documents = text_splitter.split_documents(documents)

    # NOTE(review): a new cluster connection is opened on every call and is
    # never shut down here, because the returned store keeps using the session.
    astrasession = Cluster(
        cloud={"secure_connect_bundle": ASTRA_DB_SECURE_BUNDLE_PATH},
        auth_provider=PlainTextAuthProvider("token", ASTRA_DB_APPLICATION_TOKEN),
    ).connect()

    # Use ONE table name for both the truncate and the store.  Previously the
    # truncate targeted ASTRA_DB_TABLE while the store wrote to the hard-coded
    # "qa_mini_demo", so the table being cleared and the table being filled
    # could differ.  The historic name is kept as the fallback.
    table_name = ASTRA_DB_TABLE or "qa_mini_demo"

    # Build the store before truncating so the table exists on a first run.
    # Assumes constructing the store provisions the table when it is missing
    # -- TODO confirm against the installed langchain_community version.
    astra_vector_store = Cassandra(
        embedding=embeddings,
        table_name=table_name,
        session=astrasession,
        keyspace=ASTRA_DB_KEYSPACE,
    )

    # Drop rows from any previous upload so answers come only from this PDF.
    astrasession.execute(f'TRUNCATE {ASTRA_DB_KEYSPACE}.{table_name}')

    astra_vector_store.add_documents(final_documents)
    return astra_vector_store
def prompt_temp():
    """Build the chat prompt used to answer a question from supplied context.

    Returns:
        A ``ChatPromptTemplate`` with two placeholders, ``{context}`` and
        ``{input}``.
    """
    # The empty first and last entries keep the leading and trailing newline
    # of the template text.
    template_lines = (
        "",
        "Answer the question based on the provided context only.",
        "Please provide the most accurate response based on the question.",
        "{context},",
        "Questions:{input}",
        "",
    )
    return ChatPromptTemplate.from_template("\n".join(template_lines))
def generate_response(llm,prompt,user_input,vectorstore):
    """Run a retrieval chain over ``vectorstore`` for ``user_input``.

    Args:
        llm: Chat model passed to ``create_stuff_documents_chain``.
        prompt: Prompt template passed to ``create_stuff_documents_chain``.
        user_input: The question; sent to the chain under the ``"input"`` key.
        vectorstore: Store used to build a similarity retriever with ``k=5``.

    Returns:
        Whatever the retrieval chain's ``invoke`` returns, unmodified. The
        caller in this file reads its ``'answer'`` entry.
    """
    stuff_chain = create_stuff_documents_chain(llm, prompt)
    top_k_retriever = vectorstore.as_retriever(
        search_type="similarity",
        search_kwargs={"k": 5},
    )
    rag_chain = create_retrieval_chain(top_k_retriever, stuff_chain)
    return rag_chain.invoke({"input": user_input})
def main():
    """Render the Streamlit page and answer a question about an uploaded PDF.

    When Submit is clicked with both a prompt and a PDF present, the upload is
    written to a temporary file, indexed through ``doc_loader``, queried
    through ``generate_response``, and the ``'answer'`` entry of the response
    is written to the page.  If either input is missing a warning is shown
    instead of silently doing nothing.
    """
    st.set_page_config(page_title='Chat Groq Demo')
    st.header('Chat Groq Demo')
    user_input = st.text_input('Enter the Prompt here')
    file = st.file_uploader('Choose Invoice File', type='pdf')
    submit = st.button("Submit")

    # Same result as resetting the flag to False and setting it to True on click.
    st.session_state.submit_clicked = bool(submit)

    if not submit:
        return
    if not (user_input and file):
        st.warning('Please enter a prompt and upload a PDF file.')
        return

    # delete=False because the file is closed here and then reopened by path
    # inside doc_loader; it is removed explicitly below.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as temp_file:
        temp_file.write(file.getbuffer())
        file_path = temp_file.name

    llm = ChatGroq(groq_api_key=groq_api_key, model_name="gemma-7b-it")
    prompt = prompt_temp()
    try:
        vectorstore = doc_loader(file_path)
    finally:
        # The PDF is only needed while it is being indexed; previously the
        # temporary file was never removed, leaking one file per submit.
        os.unlink(file_path)

    response = generate_response(llm, prompt, user_input, vectorstore)
    st.write(response['answer'])
# Script entry point: call main() only when this file is run as the main module.
if __name__ == "__main__":
    main()