import streamlit as st
import tempfile
import base64
import os
from src.utils.ingest_text import create_vector_database
from src.utils.ingest_image import extract_and_store_images
from src.utils.text_qa import qa_bot
from src.utils.image_qa import query_and_print_results
import nest_asyncio
from langchain.memory import ConversationBufferWindowMemory
from langchain_community.chat_message_histories import StreamlitChatMessageHistory
from dotenv import load_dotenv
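# nest_asyncio patches the running event loop so the async ingestion helpers
# can be awaited from inside Streamlit's own loop.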
nest_asyncio.apply()

load_dotenv()

# st.set_page_config must be the first Streamlit command in the script, so it
# runs here rather than further down (where it was previously commented out).
st.set_page_config(layout='wide', page_title="Virtual Tutor")

# Window memory keeps only the last k=3 exchanges, backed by Streamlit session
# state so the chat history survives reruns.
memory_storage = StreamlitChatMessageHistory(key="chat_messages")
memory = ConversationBufferWindowMemory(
    memory_key="chat_history",
    human_prefix="User",
    chat_memory=memory_storage,
    k=3,
)

# Background image shipped with the repo (data/ folder).
image_bg = r"data/pexels-fwstudio-33348-129731.jpg"

def add_bg_from_local(image_file):
    """Set a local JPEG as the full-page background via an inline data URI."""
    with open(image_file, "rb") as f:
        encoded_string = base64.b64encode(f.read()).decode()
    st.markdown(
        f"""<style>.stApp {{
        background-image: url(data:image/jpeg;base64,{encoded_string});
        background-size: cover;
        }}</style>""",
        unsafe_allow_html=True,
    )

add_bg_from_local(image_bg)

st.markdown("""
    <svg width="600" height="100">
        <text x="50%" y="50%" font-family="sans-serif" font-size="42px" fill="black" text-anchor="middle" stroke="white"
         stroke-width="0.3" stroke-linejoin="round">MULTIMODAL RAG CHAT
        </text>
    </svg>
""", unsafe_allow_html=True)


def get_answer(query, chain):
    """Run the retrieval QA chain on `query` and return the answer text, or None on failure."""
    try:
        response = chain.invoke(query)
        return response['result']
    except Exception as e:
        st.error(f"Error in get_answer: {e}")
        return None


uploaded_file = st.file_uploader("Upload a PDF", type="pdf")
if uploaded_file is not None:
    # Persist the upload to disk so the ingestion utilities can read it by path.
    os.makedirs("temp", exist_ok=True)
    temp_file_path = os.path.join("temp", uploaded_file.name)
    with open(temp_file_path, "wb") as f:
        f.write(uploaded_file.getbuffer())

    path = os.path.abspath(temp_file_path)
    st.success(f"Document uploaded successfully: {path}")

if st.button("Start Processing"):
    if uploaded_file is not None:
        with st.spinner("Processing"):
            try:
                client = create_vector_database(path)
                image_vdb = extract_and_store_images(path)
                chain = qa_bot(client)
                st.session_state['chain'] = chain
                st.session_state['image_vdb'] = image_vdb
                st.success("Processing complete.")
            except Exception as e:
                st.error(f"Error during processing: {e}")
    else:
        st.error("Please upload a file before starting processing.")

st.markdown("""
    <style> 
    .stChatInputContainer > div {
    background-color: #000000;
    }
    </style>
    """, unsafe_allow_html=True)


if user_input := st.chat_input("User Input"):
    if 'chain' in st.session_state and 'image_vdb' in st.session_state:
        chain = st.session_state['chain']
        image_vdb = st.session_state['image_vdb']

        with st.chat_message("user"):
            st.markdown(user_input)

        with st.spinner("Generating Response..."):
            response = get_answer(user_input, chain)
            if response:
                with st.chat_message("assistant"):
                    st.markdown(response)
                # save_context expects separate input/output dicts; it appends
                # both turns to the Streamlit-backed history in one call.
                memory.save_context({"input": user_input}, {"output": response})
                try:
                    query_and_print_results(image_vdb, user_input)
                except Exception as e:
                    st.error(f"Error querying image database: {e}")
            else:
                st.error("Failed to generate response.")
    else:
        st.error("Please start processing before entering user input.")

if "messages" not in st.session_state:
    st.session_state.messages = []

for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.write(message["content"])

for i, msg in enumerate(memory_storage.messages):
    name = "user" if i % 2 == 0 else "assistant"
    st.chat_message(name).markdown(msg.content)