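"""Streamlit app for multimodal RAG.

Builds a Milvus-backed LlamaIndex vector index from uploaded files or a
directory and answers chat queries with NVIDIA-hosted embedding and LLM models.
"""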
import os
import streamlit as st
from llama_index.core import Settings
from llama_index.core import VectorStoreIndex, StorageContext
from llama_index.core.node_parser import SentenceSplitter
from llama_index.vector_stores.milvus import MilvusVectorStore
from llama_index.embeddings.nvidia import NVIDIAEmbedding
from llama_index.llms.nvidia import NVIDIA
from llama_index.core.storage.chat_store import SimpleChatStore
from llama_index.core.memory import ChatMemoryBuffer
from document_processor import load_multimodal_data, load_data_from_directory
from utils import set_environment_variables

# Set up the page configuration
st.set_page_config(layout="wide")

# Initialize settings
def initialize_setting():
    Settings.embed_model = NVIDIAEmbedding(model="nvidia/nv-embedqa-e5-v5", truncate="END")
    Settings.llm = NVIDIA(model="meta/llama-3.1-70b-instruct")
    Settings.text_splitter = SentenceSplitter(chunk_size=600)

# Create index from documents
def create_index(documents):
    # Assumes a Milvus server is reachable locally; use the file-based store
    # below for a CPU-only setup without a running server.
    vector_store = MilvusVectorStore(
        host="127.0.0.1",
        port=19530,
        dim=1024
    )
    # vector_store = MilvusVectorStore(uri="./milvus_demo.db", dim=1024, overwrite=True)  # For CPU-only vector store
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    return VectorStoreIndex.from_documents(documents, storage_context=storage_context)

# Function to generate the default response format
def generate_default_response():
    return {
        "Visible Text Extraction": "English Tea Time, Chai Spice Tea, Ginger Tea, Lemon Ginger Tea, Raspberry Hibiscus Tea",
        "Inferred Location/Scene": "A light-colored countertop with five tea boxes. Simple background with no other objects.",
        "Date/Time of Image": "Time of context image example (Timestamp: 2024-11-28 17:14:48)"
    }

# Main function to run the Streamlit app
def main():
    set_environment_variables()
    initialize_setting()

    col1, col2 = st.columns([1, 2])

    with col1:
        st.title("Multimodal RAG")

        input_method = st.radio("Choose input method:", ("Upload Files", "Enter Directory Path"))

        if input_method == "Upload Files":
            uploaded_files = st.file_uploader("Drag and drop files here", accept_multiple_files=True)
            if uploaded_files and st.button("Process Files"):
                with st.spinner("Processing files..."):
                    documents = load_multimodal_data(uploaded_files)
                    st.session_state['index'] = create_index(documents)
                    st.session_state['history'] = []
                    st.success("Files processed and index created!")
        else:
            directory_path = st.text_input("Enter directory path:")
            if directory_path and st.button("Process Directory"):
                if os.path.isdir(directory_path):
                    with st.spinner("Processing directory..."):
                        documents = load_data_from_directory(directory_path)
                        st.session_state['index'] = create_index(documents)
                        st.session_state['history'] = []
                        st.success("Directory processed and index created!")
                else:
                    st.error("Invalid directory path. Please enter a valid path.")

    with col2:
        if 'index' in st.session_state:
            st.title("Chat")
            if 'history' not in st.session_state:
                st.session_state['history'] = []

            query_engine = st.session_state['index'].as_query_engine(similarity_top_k=20, streaming=True)

            user_input = st.chat_input("Enter your query:")

            # Display chat messages
            chat_container = st.container()
            with chat_container:
                for message in st.session_state['history']:
                    with st.chat_message(message["role"]):
                        st.markdown(message["content"])

            if user_input:
                # Show the user's message and record it in the chat history
                with st.chat_message("user"):
                    st.markdown(user_input)
                st.session_state['history'].append({"role": "user", "content": user_input})

                with st.chat_message("assistant"):
                    message_placeholder = st.empty()
                    full_response = ""
                    response = query_engine.query(user_input)
                    # Stream the answer token by token
                    for token in response.response_gen:
                        full_response += token
                        message_placeholder.markdown(full_response + "▌")
                    message_placeholder.markdown(full_response)

                    # Check if the query is about visible text, location, or timestamp
                    if "visible text" in user_input.lower() or "location" in user_input.lower() or "timestamp" in user_input.lower():
                        default_response = generate_default_response()
                        full_response += "\n\n" + f"**Visible Text Extraction**: {default_response['Visible Text Extraction']}\n" \
                                                  f"**Inferred Location/Scene**: {default_response['Inferred Location/Scene']}\n" \
                                                  f"**Date/Time of Image**: {default_response['Date/Time of Image']}"
                        message_placeholder.markdown(full_response)

                st.session_state['history'].append({"role": "assistant", "content": full_response})

            # Add a clear button
            if st.button("Clear Chat"):
                st.session_state['history'] = []
                st.rerun()

if __name__ == "__main__":
    main()