"""Streamlit app: answer questions over uploaded PDFs, then have an LLM SQL
agent turn the answer into `programme` / `episode` INSERT scripts against a
local SQLite database."""

import os
import re
import sqlite3
from pathlib import Path

import streamlit as st
from dotenv import load_dotenv
from langchain.agents import create_sql_agent, create_react_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains.retrieval import create_retrieval_chain
from langchain.sql_database import SQLDatabase
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
from PyPDF2 import PdfReader
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker

load_dotenv()
# Fail fast with a TypeError here if either key is missing from the environment.
os.environ['GROQ_API_KEY'] = os.getenv("GROQ_API_KEY")
os.environ['HF_TOKEN'] = os.getenv("HF_TOKEN")

st.set_page_config("Langchain interaction with DB")
st.title("Document QnA with DB interaction")

llm = ChatGroq(model="llama3-8b-8192", api_key=os.environ['GROQ_API_KEY'])
embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

# Matches durations such as "60 min", "60 mins", "60 minute", "60 minutes".
duration_pattern = re.compile(r"(\d+)\s*(min[s]?|minute[s]?)")

# Initialise conversational state once per session. (A plain assignment would
# re-run on every Streamlit rerun and wipe these values each time.)
st.session_state.setdefault("user_prompt", "")
st.session_state.setdefault("summary", "")

pdf_prompt_template = ChatPromptTemplate.from_template("""
Answer the following question from the provided context only.
Please provide the most accurate response based on the question
{context}
Question : {input}
""")


def get_pdf_text(pdf_docs):
    """Concatenate the extracted text of every page of every uploaded PDF."""
    text = ""
    for pdf in pdf_docs:
        pdf_reader = PdfReader(pdf)
        for page in pdf_reader.pages:
            # extract_text() may return None for pages with no extractable
            # text; guard so concatenation never raises TypeError.
            text += page.extract_text() or ""
    return text


def create_vector_embeddings(pdfText):
    """Build a FAISS vector store from the uploaded PDFs (once per session).

    Stores the raw text, splitter, chunks and the vector store in
    ``st.session_state`` so reruns reuse them instead of re-embedding.
    """
    if "vectors" not in st.session_state:
        st.session_state.docs = get_pdf_text(pdfText)
        st.session_state.splitter = RecursiveCharacterTextSplitter(
            chunk_size=1200, chunk_overlap=400
        )
        st.session_state.final_docs = st.session_state.splitter.split_text(
            st.session_state.docs
        )
        st.session_state.vectors = FAISS.from_texts(
            st.session_state.final_docs, embeddings
        )


def configure():
    """Return a SQLAlchemy engine bound to the local ``programme.db`` SQLite file.

    Uses a custom ``creator`` so the sqlite3 connection is opened via a file
    URI with ``check_same_thread=False`` (Streamlit may touch it from
    different threads).
    """
    dbfilepath = (Path(__file__).parent / "programme.db").absolute()
    creator = lambda: sqlite3.connect(
        f"file:{dbfilepath}", uri=True, check_same_thread=False
    )
    return create_engine("sqlite:///", creator=creator)


engine = configure()
db = SQLDatabase(engine)

# NOTE(review): SQLDatabaseToolkit does not document an ``api_key`` field —
# the LLM already carries the key. Kept for compatibility with the installed
# LangChain version; confirm it is accepted/ignored rather than required.
sql_toolkit = SQLDatabaseToolkit(db=db, llm=llm, api_key=os.environ['GROQ_API_KEY'])
sql_toolkit.get_tools()


def clear_database():
    """Delete all rows from the ``programme`` and ``episode`` tables."""
    connection = engine.raw_connection()
    try:
        cursor = connection.cursor()
        # Table names come from this fixed list, never from user input, so the
        # f-string interpolation below is not an injection risk.
        for table in ("programme", "episode"):
            cursor.execute(f"DELETE FROM {table}")
        connection.commit()
    finally:
        # Always release the raw DB-API connection back to the pool.
        connection.close()


def process_sql_script(sql_script):
    """Keep only the first statement unless the script mentions 'PACKAGE'.

    The LLM emits one or two INSERT statements separated by ';'. When the
    first (programme) statement does not contain the PACKAGE keyword, any
    trailing episode statement is dropped.
    """
    keyword = 'PACKAGE'
    lines = sql_script.strip().split(';')
    programme_line = lines[0]
    if keyword not in programme_line:
        filtered_script = lines[0]
    else:
        filtered_script = "\n".join(lines)
    return filtered_script


def convert_to_hms(duration):
    """Convert a duration like '60 mins' or '1 hour 30 mins' to 'hh:mm:ss'.

    Returns the input unchanged when it does not match the expected pattern.
    """
    hour_minute_match = re.match(
        r'(?:(\d+)\s*hour[s]?)?\s*(\d+)\s*min[s]?', duration.lower()
    )
    if hour_minute_match:
        hours = int(hour_minute_match.group(1) or 0)
        minutes = int(hour_minute_match.group(2) or 0)
    else:
        return duration
    total_seconds = (hours * 60 * 60) + (minutes * 60)
    hh = total_seconds // 3600
    mm = (total_seconds % 3600) // 60
    ss = total_seconds % 60
    return f"{hh:02}:{mm:02}:{ss:02}"


def handleDurationForEachScript(scripts):
    """Rewrite quoted durations in each SQL statement to hh:mm:ss form.

    NOTE(review): the accumulator guard below stops appending as soon as the
    collected text mentions 'episode' or 'programme', so in practice only the
    first statement survives — confirm this is the intended filtering.
    """
    filtered_data = ""
    pattern = r"'(\d+\s*(?:mins|minutes))'"
    for script in scripts.split(";"):
        match = re.search(pattern, script)
        if match:
            duration = match.group(1)
            converted_duration = convert_to_hms(duration)  # hh:mm:ss form
            # Strip the 'utes' leftover when the text said 'minutes'.
            script = script.replace(duration, converted_duration).replace('utes', '')
        if ('episode' not in filtered_data) and ('programme' not in filtered_data):
            filtered_data = filtered_data + script
    return filtered_data


def parse_insert_statement(insert_statement):
    """Split an INSERT statement into (table, columns, values) strings.

    Returns (None, None, None) when the statement cannot be parsed.
    """
    table_match = re.search(r'INSERT INTO (\w+)', insert_statement)
    if not table_match:
        return None, None, None
    table = table_match.group(1)
    # First parenthesised group is the column list; VALUES(...) holds the row.
    columns_match = re.search(r'\((.*?)\)', insert_statement, re.DOTALL)
    values_match = re.search(r'VALUES\s*\((.*?)\)', insert_statement, re.DOTALL)
    if not columns_match or not values_match:
        return None, None, None
    columns = columns_match.group(1).replace('"', '').replace('\n', ' ').strip()
    values = values_match.group(1).replace("'", "").replace('\n', ' ').strip()
    return table, columns, values


def build_data_from_sql(programme_sql, episode_sql=None):
    """Build the summary-table dict shown in the UI from the INSERT scripts."""
    data = {
        'Table': [],
        'Columns': [],
        'Values': []
    }
    programme_table, programme_columns, programme_values = parse_insert_statement(
        programme_sql
    )
    if programme_table and programme_columns and programme_values:
        data['Table'].append(programme_table.capitalize())
        data['Columns'].append(programme_columns)
        data['Values'].append(programme_values)
    # The episode statement only exists for PACKAGE programmes.
    if episode_sql:
        episode_table, episode_columns, episode_values = parse_insert_statement(
            episode_sql
        )
        if episode_table and episode_columns and episode_values:
            data['Table'].append(episode_table.capitalize())
            data['Columns'].append(episode_columns)
            data['Values'].append(episode_values)
    return data


with st.sidebar:
    st.title("Menu:")
    st.session_state.uploaded_text = st.file_uploader(
        "Upload your Files and Click on the Submit & Process Button",
        accept_multiple_files=True,
    )
    if st.button("Click To Process File"):
        with st.spinner("Processing..."):
            create_vector_embeddings(st.session_state.uploaded_text)
            st.write("Vector Database is ready")

query = st.text_input("ask question here")

if query and "vectors" in st.session_state:
    st.session_state.user_prompt = query
    document_chain = create_stuff_documents_chain(llm=llm, prompt=pdf_prompt_template)
    retriever = st.session_state.vectors.as_retriever()
    retrieval_chain = create_retrieval_chain(retriever, document_chain)
    response = retrieval_chain.invoke({"input": st.session_state.user_prompt})
    if response:
        # Cache the answer: it becomes the context for SQL generation below.
        st.session_state.summary = response['answer']
        st.write(response['answer'])

prompt = ChatPromptTemplate.from_messages(
    [
        ("system",
         """
You are a SQL expert. Your task is to generate SQL INSERT scripts based on the provided context.
1. Generate an `INSERT` statement for the `programme` table using the following values:
- `ProgrammeTitle`
- `ProgrammeType`
- `Genre`
- `SubGenre`
- `Language`
- `Duration`
Example:
2. After generating the `programme` statement, check the `ProgrammeTitle`:
- If the `ProgrammeTitle` contains the keyword `PACKAGE`, generate an additional `INSERT` statement for the `episode` table.
- If the `ProgrammeTitle` does **not** contain the keyword `PACKAGE`, **do not** generate an `INSERT` statement for the `episode` table.
3. The `episode` INSERT statement should look like this if the condition is met. EpisodeNumber is always 1 and `EpisodeTitle` should take same data from `ProgrammeTitle`.
4. Include only the SQL insert script(s) as final answer, **donot** include any additional details and notes.Return only the necessary SQL INSERT script(s) based on the current input. Ensure that no `episode` INSERT statement is included if the `ProgrammeTitle` does not contain `'PACKAGE'`. Your output should strictly follow these conditions. Output **only** the final answer without producing any intermediate actions.
"""
         ),
        ("user", "{question}\
 ai: ")
    ]
)

agent = create_sql_agent(
    llm=llm,
    toolkit=sql_toolkit,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    verbose=True,
    max_execution_time=100,
    max_iterations=1000,
    handle_parsing_errors=True,
)

if st.button("Generate Scripts", type="primary"):
    try:
        # Only invoke the agent once a PDF answer actually exists (the old
        # `is not None` check was always true because the value is a string).
        if st.session_state.summary:
            response = agent.run(prompt.format_prompt(question=st.session_state.summary))
            if "INSERT" in response:
                final_response = process_sql_script(response)
                final_response_new = handleDurationForEachScript(final_response)
                episode_sql = ""
                splitted_data = []
                if "uploaded_text" in st.session_state and st.session_state.uploaded_text is not None:
                    uploaded_file_names = [f.name for f in st.session_state.uploaded_text]
                    if any('PACKAGE' in name.upper() for name in uploaded_file_names):
                        # PACKAGE uploads may yield two statements; try the
                        # likeliest separators in order of preference.
                        if ";" in final_response_new:
                            splitted_data = [
                                stmt.strip()
                                for stmt in final_response_new.strip().split(';')
                                if stmt.strip()
                            ]
                        elif "\n" in final_response_new:
                            splitted_data = [
                                stmt.strip()
                                for stmt in final_response_new.strip().split('\n')
                                if stmt.strip()
                            ]
                        elif "," in final_response_new:
                            splitted_data = [
                                stmt.strip()
                                for stmt in final_response_new.strip().split(',')
                                if stmt.strip()
                            ]
                    else:
                        # Single programme statement expected. (The original
                        # `final_response_new is list` identity check against
                        # the type object was always False.)
                        if isinstance(final_response_new, list):
                            splitted_data = final_response_new
                        else:
                            splitted_data.append(final_response_new)
                print(splitted_data)
                if len(splitted_data) > 0:
                    # Re-add the semicolon stripped by the split above.
                    programme_sql = splitted_data[0] + ';'
                    print(f"prog{programme_sql}")
                    if len(splitted_data) > 1:
                        episode_sql = splitted_data[1]
                    data = build_data_from_sql(programme_sql, episode_sql)
                    st.write("### Script Summary")
                    st.table(data)
                    st.write("### Full SQL Scripts")
                    with st.expander("Insert Scripts"):
                        st.code(programme_sql, language='sql')
                        st.code(episode_sql, language='sql')
                # NOTE(review): this wipes both tables right after displaying
                # the scripts — confirm the inserts are meant to be transient.
                clear_database()
    except Exception as e:
        st.error(f"Parsing error from LLM.Retry again !!! \n : {str(e)}")