"""Streamlit app: answer questions over uploaded PDFs, then turn the answer into SQL.

Flow of one page run:

1. Sidebar: PDFs are uploaded and, on "Click To Process File", their text is
   split into chunks and indexed in a FAISS vector store kept in
   ``st.session_state.vectors``.
2. Main pane: the question typed in the text box is answered by a retrieval
   chain over that store; the answer is cached in ``st.session_state.summary``
   (keyed by the question in ``st.session_state.user_prompt``).
3. "Generate Scripts": a LangChain SQL agent turns the cached answer into
   ``INSERT`` statements for the ``programme`` / ``episode`` tables, which are
   post-processed (episode filtering, duration normalisation) and displayed.
"""
import os
import re
import sqlite3
from pathlib import Path

import streamlit as st
from dotenv import load_dotenv
from langchain.agents import create_react_agent, create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains.retrieval import create_retrieval_chain
from langchain.sql_database import SQLDatabase
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
from PyPDF2 import PdfReader
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker

load_dotenv()  # copy variables from a local .env file (if present) into os.environ

st.set_page_config("Langchain interaction with DB")
st.title("Langchain with DB")

# Fail fast with a readable message instead of a TypeError during setup.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    st.error("GROQ_API_KEY is not set. Add it to your environment or .env file.")
    st.stop()


@st.cache_resource
def _load_llm(api_key):
    """Create the Groq chat model once and reuse it across Streamlit reruns."""
    return ChatGroq(model="llama3-8b-8192", api_key=api_key)


@st.cache_resource
def _load_embeddings():
    """Load the sentence-transformer embedding model once and reuse it across reruns."""
    return HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")


llm = _load_llm(GROQ_API_KEY)
embeddings = _load_embeddings()

# Matches a free-text duration such as "60 mins", "60minutes" or
# "1 hour 30 mins". The longer unit spellings are listed first so the whole
# word is consumed; a "min"-first alternation would leave a stray "utes" behind.
duration_pattern = re.compile(r"\b(?:(\d+)\s*hours?\s*)?(\d+)\s*(?:minutes?|mins?)\b")

# Parses the text captured by duration_pattern inside convert_to_hms.
_HMS_INPUT_PATTERN = re.compile(r"(?:(\d+)\s*hour[s]?)?\s*(\d+)\s*min[s]?")

# A line that only opens/closes a markdown code fence, e.g. "```sql" or "```".
_CODE_FENCE_PATTERN = re.compile(r"^[ \t]*```[\w+-]*[ \t]*$", re.MULTILINE)

# Initialise once per browser session. Streamlit reruns this whole script on
# every interaction, so unconditional assignment would wipe the cached answer.
if "user_prompt" not in st.session_state:
    st.session_state.user_prompt = ""
if "summary" not in st.session_state:
    st.session_state.summary = ""

pdf_prompt_template = ChatPromptTemplate.from_template(
    """
Answer the following question from the provided context only.
Please provide the most accurate response based on the question
{context}
Question : {input}
"""
)


def get_pdf_text(pdf_docs):
    """Return the extracted text of every page of every PDF in *pdf_docs*, concatenated.

    Args:
        pdf_docs: iterable of PDFs accepted by ``PdfReader`` (paths or file-like
            objects such as Streamlit ``UploadedFile``). ``None`` is treated as empty.
    """
    page_texts = []
    for pdf in pdf_docs or []:
        for page in PdfReader(pdf).pages:
            # Defensive: treat a None/empty extraction result as no text.
            page_texts.append(page.extract_text() or "")
    return "".join(page_texts)


def create_vector_embeddings(pdfText):
    """Build a FAISS vector store from uploaded PDFs and store it in session state.

    Rebuilds on every call, so re-processing picks up newly uploaded files.

    Args:
        pdfText: iterable of PDF files (despite the name, files, not text).

    Returns:
        True when a store was built. False when no text could be extracted;
        in that case any previously built store is left untouched.
    """
    docs = get_pdf_text(pdfText)
    if not docs.strip():
        return False
    splitter = RecursiveCharacterTextSplitter(chunk_size=1200, chunk_overlap=400)
    chunks = splitter.split_text(docs)
    if not chunks:
        return False
    st.session_state.docs = docs
    st.session_state.splitter = splitter
    st.session_state.final_docs = chunks
    st.session_state.vectors = FAISS.from_texts(chunks, embeddings)
    # Any cached answer came from the previous documents; force a fresh one.
    st.session_state.user_prompt = ""
    st.session_state.summary = ""
    return True


@st.cache_resource
def configure():
    """Return a SQLAlchemy engine for ``programme.db`` located next to this file.

    The connection goes through a ``sqlite3`` URI built with ``Path.as_uri()``,
    which percent-encodes characters (spaces, ``?``, ``#``) that would otherwise
    break a hand-written ``file:`` URI. ``check_same_thread=False`` lets the
    connection be used from Streamlit's worker threads.
    """
    db_uri = (Path(__file__).parent / "programme.db").absolute().as_uri()

    def _connect():
        return sqlite3.connect(db_uri, uri=True, check_same_thread=False)

    return create_engine("sqlite:///", creator=_connect)


def clear_database():
    """Delete every row from the ``episode`` and ``programme`` tables in one transaction.

    ``engine.begin()`` commits on success and rolls back if any DELETE fails.
    Table names are fixed constants, so interpolating them is safe.
    """
    # Child table first, presumably episode references programme; this keeps
    # the DELETEs valid even if SQLite foreign-key enforcement is switched on.
    with engine.begin() as connection:
        for table in ("episode", "programme"):
            connection.execute(text(f"DELETE FROM {table}"))


def process_sql_script(sql_script, keyword="PACKAGE"):
    """Keep the episode INSERT only for PACKAGE programmes and normalise statement form.

    The script is split on ``;``. If the first statement (the ``programme``
    INSERT) does not contain *keyword*, only that statement is kept. Otherwise
    all statements are kept. Markdown code-fence lines are removed. Each kept
    statement is re-terminated with ``;`` and statements are joined by newlines.

    Returns:
        The cleaned script, or ``""`` if it contained no statements.
    """
    cleaned = _CODE_FENCE_PATTERN.sub("", sql_script)
    statements = [part.strip() for part in cleaned.split(";")]
    statements = [statement for statement in statements if statement]
    if not statements:
        return ""
    if keyword not in statements[0]:
        statements = statements[:1]
    return "\n".join(f"{statement};" for statement in statements)


def convert_to_hms(duration):
    """Convert a duration like ``"90 mins"`` or ``"1 hour 30 mins"`` to ``"HH:MM:SS"``.

    Matching is case-insensitive. Strings that do not start with a recognised
    duration are returned unchanged.

    >>> convert_to_hms("1 hour 30 mins")
    '01:30:00'
    """
    match = _HMS_INPUT_PATTERN.match(duration.lower())
    if not match:
        return duration
    total_minutes = int(match.group(1) or 0) * 60 + int(match.group(2))
    hours, minutes = divmod(total_minutes, 60)
    return f"{hours:02}:{minutes:02}:00"


def handleDurationForEachScript(scripts):
    """Rewrite every free-text duration in *scripts* as ``HH:MM:SS``.

    Every statement is kept. The substitution runs over the whole script, so
    semicolons inside string literals are not disturbed.
    """
    return duration_pattern.sub(lambda match: convert_to_hms(match.group(0)), scripts)


def _answer_from_pdfs(question):
    """Answer *question* with a retrieval chain over ``st.session_state.vectors``."""
    document_chain = create_stuff_documents_chain(llm=llm, prompt=pdf_prompt_template)
    retriever = st.session_state.vectors.as_retriever()
    retrieval_chain = create_retrieval_chain(retriever, document_chain)
    response = retrieval_chain.invoke({"input": question})
    return response.get("answer", "") if response else ""


engine = configure()
db = SQLDatabase(engine)
sql_toolkit = SQLDatabaseToolkit(db=db, llm=llm)

query = st.text_input("ask question here")

with st.sidebar:
    st.title("Menu:")
    st.session_state.uploaded_text = st.file_uploader(
        "Upload your Files and Click on the Submit & Process Button",
        accept_multiple_files=True,
        type="pdf",
    )
    if st.button("Click To Process File"):
        if not st.session_state.uploaded_text:
            st.warning("Please upload at least one PDF file first.")
        else:
            with st.spinner("Processing..."):
                built = create_vector_embeddings(st.session_state.uploaded_text)
            if built:
                st.write("Vector Database is ready")
            else:
                st.warning("No extractable text was found in the uploaded files.")

if query:
    if "vectors" not in st.session_state:
        st.warning("Upload PDF files and click 'Click To Process File' before asking a question.")
    else:
        # Only call the LLM when the question changed (or no answer is cached).
        if query != st.session_state.user_prompt or not st.session_state.summary:
            st.session_state.user_prompt = query
            st.session_state.summary = _answer_from_pdfs(query)
        if st.session_state.summary:
            st.write(st.session_state.summary)

# NOTE(review): sections 1 and 3 of this prompt announce examples that are not
# present; they appear to have been lost. Confirm and restore them. Also, the
# last instruction ("without producing any intermediate actions") conflicts
# with the ReAct Thought/Action format the SQL agent must emit.
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            """
You are a SQL expert. Your task is to generate SQL INSERT scripts based on the provided context.

1. Generate an `INSERT` statement for the `programme` table using the following values:
   - `ProgrammeTitle`
   - `ProgrammeType`
   - `Genre`
   - `SubGenre`
   - `Language`
   - `Duration`

   Example:

2. After generating the `programme` statement, check the `ProgrammeTitle`:
   - If the `ProgrammeTitle` contains the keyword `PACKAGE`, generate an additional `INSERT` statement for the `episode` table.
   - If the `ProgrammeTitle` does **not** contain the keyword `PACKAGE`, **do not** generate an `INSERT` statement for the `episode` table.

3. The `episode` INSERT statement should look like this if the condition is met. EpisodeNumber is always 1 and `EpisodeTitle` should take same data from `ProgrammeTitle`.

4. Include only the SQL insert script(s) as final answer, **donot** include any additional details and notes.Return only the necessary SQL INSERT script(s) based on the current input.

Ensure that no `episode` INSERT statement is included if the `ProgrammeTitle` does not contain `'PACKAGE'`.
Your output should strictly follow these conditions.
Output **only** the final answer without producing any intermediate actions.
""",
        ),
        ("user", "{question}\nai: "),
    ]
)

# handle_parsing_errors is an AgentExecutor option. create_sql_agent forwards
# extra **kwargs to the inner agent, so it must go in agent_executor_kwargs.
agent = create_sql_agent(
    llm=llm,
    toolkit=sql_toolkit,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    verbose=True,
    max_execution_time=100,
    max_iterations=1000,
    agent_executor_kwargs={"handle_parsing_errors": True},
)

if st.button("Generate Scripts", type="primary"):
    # Use the cached answer only if it belongs to the question currently typed.
    summary = (
        st.session_state.summary
        if query and query == st.session_state.user_prompt
        else ""
    )
    if not summary:
        st.warning("Ask a question about the uploaded PDFs first; its answer is used to generate the scripts.")
    else:
        try:
            result = agent.invoke({"input": prompt.format(question=summary)})
            response = result.get("output", "")
            with st.expander("Expand here to view scripts"):
                if "INSERT" in response:
                    final_response = process_sql_script(response)
                    final_response_new = handleDurationForEachScript(final_response)
                    st.code(final_response_new, language="sql")
                else:
                    st.warning("The agent did not return any INSERT statements. Please retry.")
            # The agent's SQL tools operate on the live database; wipe both tables
            # so each run starts clean (presumably the intent — confirm).
            clear_database()
        except Exception as e:
            st.error(f"Parsing error from LLM.Retry again !!! \n : {str(e)}")

# TODO: optionally show a per-table summary (st.table) and a separate
# expander for the programme and the episode INSERT script.