"""Streamlit demo: ask questions about an uploaded PDF using Groq + Astra DB.

The uploaded PDF is chunked, embedded with a local BGE model, stored in an
Astra DB (Cassandra) vector table, and queried through a LangChain retrieval
chain backed by a Groq-hosted LLM.
"""
import os
import tempfile
import time
import warnings

import cassio
import streamlit as st
from cassandra.auth import PlainTextAuthProvider
from cassandra.cluster import Cluster
from dotenv import load_dotenv
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFDirectoryLoader
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.llms import Ollama
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores import Cassandra
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq
from PyPDF2 import PdfReader

# NOTE: this silences *all* warnings process-wide, including library
# deprecation notices.
warnings.filterwarnings("ignore")

load_dotenv()

# Astra DB secure-connect bundle, expected next to this script.
ASTRA_DB_SECURE_BUNDLE_PATH = 'secure-connect-pdf-query-db.zip'
# Vector table used when ASTRA_DB_TABLE is not configured (this was the
# previously hard-coded table name).
DEFAULT_ASTRA_DB_TABLE = "qa_mini_demo"

os.environ["LANGCHAIN_TRACING_V2"] = "true"
LANGCHAIN_API_KEY = os.getenv("LANGCHAIN_API_KEY")
LANGCHAIN_PROJECT = os.getenv("LANGCHAIN_PROJECT")
LANGCHAIN_ENDPOINT = os.getenv("LANGCHAIN_ENDPOINT")
ASTRA_DB_APPLICATION_TOKEN = os.getenv("ASTRA_DB_APPLICATION_TOKEN")
ASTRA_DB_ID = os.getenv("ASTRA_DB_ID")
ASTRA_DB_KEYSPACE = os.getenv("ASTRA_DB_KEYSPACE")
ASTRA_DB_API_ENDPOINT = os.getenv("ASTRA_DB_API_ENDPOINT")
ASTRA_DB_CLIENT_ID = os.getenv("ASTRA_DB_CLIENT_ID")
ASTRA_DB_CLIENT_SECRET = os.getenv("ASTRA_DB_CLIENT_SECRET")
# Single source of truth for the vector table: both the clearing step and the
# vector store use this name, so stale documents cannot leak into retrieval.
ASTRA_DB_TABLE = os.getenv("ASTRA_DB_TABLE") or DEFAULT_ASTRA_DB_TABLE
groq_api_key = os.getenv('groq_api_key')

cloud_config = {
    'secure_connect_bundle': ASTRA_DB_SECURE_BUNDLE_PATH
}


@st.cache_resource(show_spinner=False)
def _init_cassio():
    """Initialise the global cassio connection once per server process.

    Streamlit re-executes this module on every widget interaction; without
    caching, ``cassio.init`` would open a fresh connection on each rerun.
    ``show_spinner=False`` keeps this from rendering anything before
    ``st.set_page_config`` runs in ``main``.
    """
    cassio.init(
        token=ASTRA_DB_APPLICATION_TOKEN,
        database_id=ASTRA_DB_ID,
        secure_connect_bundle=ASTRA_DB_SECURE_BUNDLE_PATH,
    )


_init_cassio()


@st.cache_resource(show_spinner=False)
def _get_embeddings():
    """Load the BGE embedding model once per process (loading is expensive)."""
    return HuggingFaceBgeEmbeddings(
        model_name='BAAI/bge-small-en-v1.5',
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': True},
    )


@st.cache_resource(show_spinner=False)
def _get_astra_session():
    """Open a single Cassandra session to Astra DB and reuse it across reruns.

    Previously a new Cluster was connected on every submit and never shut
    down, leaking connections.
    """
    cluster = Cluster(
        cloud=cloud_config,
        auth_provider=PlainTextAuthProvider("token", ASTRA_DB_APPLICATION_TOKEN),
    )
    return cluster.connect()


def doc_loader(pdf_reader: str) -> Cassandra:
    """Index a PDF into the Astra DB vector table and return the vector store.

    Args:
        pdf_reader: Filesystem path of the PDF to index.

    Returns:
        A ``Cassandra`` vector store containing only this PDF's chunks. The
        table is emptied first, so earlier uploads are discarded. The table is
        shared by every user of the app.
    """
    documents = PyPDFLoader(pdf_reader).load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    final_documents = text_splitter.split_documents(documents)

    # Creating the store also creates the table if it does not exist yet, so
    # clearing afterwards is safe even on the very first run.
    astra_vector_store = Cassandra(
        embedding=_get_embeddings(),
        table_name=ASTRA_DB_TABLE,
        session=_get_astra_session(),
        keyspace=ASTRA_DB_KEYSPACE,
    )
    # Empty the same table we query, so retrieval only sees the new upload.
    astra_vector_store.clear()
    astra_vector_store.add_documents(final_documents)
    return astra_vector_store


def prompt_temp() -> ChatPromptTemplate:
    """Build the RAG prompt; it expects ``context`` and ``input`` variables."""
    return ChatPromptTemplate.from_template(
        """
Answer the question based on provided context only.
Your context retrieval mechanism works correclty but your are not providing answer from context.
Please provide the most accurate response based on question.
{context}, Questions:{input}
"""
    )


def generate_response(llm, prompt, user_input, vectorstore):
    """Answer ``user_input`` using the top-5 most similar chunks as context.

    Returns:
        The retrieval-chain result dict; ``main`` reads its ``'answer'`` and
        ``'context'`` (retrieved documents) keys.
    """
    document_chain = create_stuff_documents_chain(llm, prompt)
    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 5})
    retrieval_chain = create_retrieval_chain(retriever, document_chain)
    return retrieval_chain.invoke({"input": user_input})


def main():
    """Render the page and answer the user's question about the uploaded PDF."""
    st.set_page_config(page_title='Chat Groq Demo')
    st.header('Chat Groq Demo')

    user_input = st.text_input('Enter the Prompt here')
    uploaded_file = st.file_uploader('Choose Invoice File', type='pdf')
    submit = st.button("Submit")

    st.session_state.submit_clicked = bool(submit)
    if not submit:
        return
    if not (user_input and uploaded_file):
        st.warning("Please enter a prompt and upload a PDF file before submitting.")
        return

    # PyPDFLoader needs a real path, so persist the upload to disk.
    # delete=False lets the loader reopen the file by name (required on
    # Windows), so the file is removed explicitly in ``finally``.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as temp_file:
        temp_file.write(uploaded_file.getbuffer())
        file_path = temp_file.name

    try:
        llm = ChatGroq(groq_api_key=groq_api_key, model_name="gemma-7b-it")
        prompt = prompt_temp()
        with st.spinner("Indexing document and generating answer..."):
            vectorstore = doc_loader(file_path)
            response = generate_response(llm, prompt, user_input, vectorstore)
    finally:
        os.remove(file_path)

    st.write(response['answer'])
    with st.expander("Document Similarity Search"):
        for doc in response['context']:
            st.write(doc.page_content)
            st.write('---------------------------------')


if __name__ == "__main__":
    main()