# data_search/data_search_page.py
# Streamlit page that searches the image and text vector indexes with a free-text query.
import os
import sys

import clip
import pandas as pd
import streamlit as st
import torch
from PIL import Image

# Make the sibling modules importable regardless of the working directory;
# this must run before the local imports below.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from vectordb import search_image_index, search_text_index
from utils import load_image_index, load_text_index, get_local_files


def data_search(clip_model, preprocess, text_embedding_model, device):
    @st.cache_resource
    def load_finetuned_model(file_name):
        # Load the base CLIP weights, then overlay the checkpoint saved by the
        # selected annotation project; cached so Streamlit reruns stay cheap.
        model, preprocess = clip.load("ViT-B/32", device=device)
        model.load_state_dict(torch.load(f"annotations/{file_name}/finetuned_model.pt", weights_only=True))
        return model, preprocess
    st.title("Data Search")

    # Validate the images directory before listing it, so a missing directory
    # surfaces as a warning instead of an uncaught OSError.
    if not os.path.isdir("images/") or not os.listdir("images/"):
        st.warning("No images found in the data directory.")
        return
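
    # Annotation projects may ship fine-tuned CLIP checkpoints; when present,
    # offer them as alternatives to the default model passed into this page.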
    annotation_projects = get_local_files("annotations/", get_details=True)
    if annotation_projects or st.session_state.get('selected_annotation_project', None) is not None:
        # Keep only the projects that actually produced a fine-tuned checkpoint.
        annotation_projects_with_model = []
        for annotation_project in annotation_projects:
            if os.path.exists(f"annotations/{annotation_project['file_name']}/finetuned_model.pt"):
                annotation_projects_with_model.append(annotation_project)
        if annotation_projects_with_model or st.session_state.get('selected_annotation_project', None) is not None:
            if st.button("Use Default Model"):
                st.session_state.pop('selected_annotation_project', None)
            annotation_projects_df = pd.DataFrame(annotation_projects_with_model)
            annotation_projects_df['file_created'] = annotation_projects_df['file_created'].dt.strftime("%Y-%m-%d %H:%M:%S")
            annotation_projects_df['display_text'] = annotation_projects_df.apply(
                lambda x: f"ID: {x['file_name']} | Time Created: ({x['file_created']})", axis=1)
            annotation_project = st.selectbox("Select Annotation Project", options=annotation_projects_df['display_text'].tolist())
            annotation_project = annotation_projects_df[annotation_projects_df['display_text'] == annotation_project].iloc[0]
            if st.button("Use Selected Fine-Tuned Model") or st.session_state.get('selected_annotation_project', None) is not None:
                with st.spinner("Loading Fine-Tuned Model..."):
                    st.session_state['selected_annotation_project'] = annotation_project
                    clip_model, preprocess = load_finetuned_model(annotation_project['file_name'])
                st.info(f"Using Fine-Tuned Model from {annotation_project['file_name']}")
            else:
                st.info("Using Default Model")
    text_input = st.text_input("Search Database")
    if st.button("Search", disabled=text_input.strip() == ""):
        # Default everything to None so a missing index file cannot leave
        # these names unbound in the checks below.
        image_index, image_data = None, None
        text_index, text_data = None, None
        if os.path.exists("./vectorstore/image_index.index"):
            image_index, image_data = load_image_index()
        if os.path.exists("./vectorstore/text_index.index"):
            text_index, text_data = load_text_index()
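        # Inference only from here on, so run the embedding passes without autograd.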
        with torch.no_grad():
            if not os.path.exists("./vectorstore/image_data.csv"):
                st.warning("No Image Index Found, so not searching for images.")
                image_index = None
            if not os.path.exists("./vectorstore/text_data.csv"):
                st.warning("No Text Index Found, so not searching for text.")
                text_index = None
            if image_index is not None:
                image_indices = search_image_index(text_input, image_index, clip_model, k=3)
            if text_index is not None:
                text_indices = search_text_index(text_input, text_index, text_embedding_model, k=3)
            if image_index is None and text_index is None:
                st.error("No Data Found! Please add data to the database.")
                return  # nothing to render
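
            # Image hits come first, each re-scored against the query with
            # CLIP cosine similarity; text hits follow in a second row.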
            st.subheader("Top 3 Results")
            cols = st.columns(3)
            for i in range(3):
                with cols[i]:
                    if image_index is not None:
                        image_path = image_data['path'].iloc[image_indices[0][i]]
                        image = Image.open(image_path)
                        image = preprocess(image).unsqueeze(0).to(device)
                        text = clip.tokenize([text_input]).to(device)
                        image_features = clip_model.encode_image(image)
                        text_features = clip_model.encode_text(text)
                        cosine_similarity = torch.cosine_similarity(image_features, text_features)
                        st.write(f"Similarity: {cosine_similarity.item() * 100:.2f}%")
                        st.image(image_path)
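
            # Second row: the top text chunks retrieved from the text index.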
            cols = st.columns(3)
            for i in range(3):
                with cols[i]:
                    if text_index is not None:
                        text_content = text_data['content'].iloc[text_indices[0][i]]
                        st.write(text_content)
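

# Usage sketch (not from the original file): one plausible way to wire this
# page into the app entry point. The model names and the SentenceTransformer
# text embedder below are illustrative assumptions, not pinned by this module.
#
#     import clip
#     import torch
#     from sentence_transformers import SentenceTransformer
#
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     clip_model, preprocess = clip.load("ViT-B/32", device=device)
#     text_embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
#     data_search(clip_model, preprocess, text_embedding_model, device)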