import os
import sys

import clip
import pandas as pd
import streamlit as st
import torch
from PIL import Image

# Make sibling modules importable before resolving the local imports below.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from vectordb import search_image_index, search_text_index, search_image_index_with_image, search_text_index_with_image
from utils import load_image_index, load_text_index, load_audio_index, get_local_files
from data_search import adapter_utils

def data_search(clip_model, preprocess, text_embedding_model, whisper_model, device):
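    """Render the Data Search page: query the local image, text, and audio
    indexes with a text prompt, an uploaded image, or both."""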
    @st.cache_resource
    def load_finetuned_model(file_name):
        model, preprocess = clip.load("ViT-B/32", device=device)
        model.load_state_dict(torch.load(f"annotations/{file_name}/finetuned_model.pt", weights_only=True))
        return model, preprocess

    @st.cache_resource
    def load_adapter():
        return adapter_utils.load_adapter_model()
    st.title("Data Search")

    images = os.listdir("images/")
    if not images:
        st.warning("No Images Found! Please upload images to the database.")
        return

    annotation_projects = get_local_files("annotations/", get_details=True)
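    # Offer any annotation project that has produced a fine-tuned CLIP
    # checkpoint as an alternative to the default model.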
    if annotation_projects or st.session_state.get('selected_annotation_project', None) is not None:
        annotation_projects_with_model = []
        for annotation_project in annotation_projects:
            if os.path.exists(f"annotations/{annotation_project['file_name']}/finetuned_model.pt"):
                annotation_projects_with_model.append(annotation_project)
        if annotation_projects_with_model or st.session_state.get('selected_annotation_project', None) is not None:
            if st.button("Use Default Model"):
                st.session_state.pop('selected_annotation_project', None)
            annotation_projects_df = pd.DataFrame(annotation_projects_with_model)
            annotation_projects_df['file_created'] = annotation_projects_df['file_created'].dt.strftime("%Y-%m-%d %H:%M:%S")
            annotation_projects_df['display_text'] = annotation_projects_df.apply(lambda x: f"ID: {x['file_name']} | Time Created: ({x['file_created']})", axis=1)
            annotation_project = st.selectbox("Select Annotation Project", options=annotation_projects_df['display_text'].tolist())
            annotation_project = annotation_projects_df[annotation_projects_df['display_text'] == annotation_project].iloc[0]
            if st.button("Use Selected Fine-Tuned Model") or st.session_state.get('selected_annotation_project', None) is not None:
                with st.spinner("Loading Fine-Tuned Model..."):
                    st.session_state['selected_annotation_project'] = annotation_project
                    clip_model, preprocess = load_finetuned_model(annotation_project['file_name'])
                st.info(f"Using Fine-Tuned Model from {annotation_project['file_name']}")
            else:
                st.info("Using Default Model")
    adapter = load_adapter()
    adapter.to(device)

    text_input = st.text_input("Search Database")
    image_input = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg"])
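    # Searching requires at least one of: a non-empty text query or an uploaded image.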
if st.button("Search", disabled=text_input.strip() == "" and image_input is None):
if os.path.exists("./vectorstore/image_index.index"):
image_index, image_data = load_image_index()
if os.path.exists("./vectorstore/text_index.index"):
text_index, text_data = load_text_index()
if os.path.exists("./vectorstore/audio_index.index"):
audio_index, audio_data = load_audio_index()
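        # All encoder calls below run without gradient tracking.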
        with torch.no_grad():
            if not os.path.exists("./vectorstore/image_data.csv"):
                st.warning("No Image Index Found. So not searching for images.")
                image_index = None
            if not os.path.exists("./vectorstore/text_data.csv"):
                st.warning("No Text Index Found. So not searching for text.")
                text_index = None
            if not os.path.exists("./vectorstore/audio_data.csv"):
                st.warning("No Audio Index Found. So not searching for audio.")
                audio_index = None
            if image_input:
                image = Image.open(image_input)
                image = preprocess(image).unsqueeze(0).to(device)
                # Already inside torch.no_grad(), so no inner no_grad block is needed.
                image_features = clip_model.encode_image(image)
                adapted_text_embeddings = adapter(image_features)
                if image_index is not None:
                    image_indices = search_image_index_with_image(image_features, image_index, clip_model, k=3)
                if text_index is not None:
                    text_indices = search_text_index_with_image(adapted_text_embeddings, text_index, text_embedding_model, k=3)
                if audio_index is not None:
                    audio_indices = search_text_index_with_image(adapted_text_embeddings, audio_index, text_embedding_model, k=3)
            else:
                # Text query: search every available index with the raw query string.
                if image_index is not None:
                    image_indices = search_image_index(text_input, image_index, clip_model, k=3)
                if text_index is not None:
                    text_indices = search_text_index(text_input, text_index, text_embedding_model, k=3)
                if audio_index is not None:
                    audio_indices = search_text_index(text_input, audio_index, text_embedding_model, k=3)
            if image_index is None and text_index is None and audio_index is None:
                st.error("No Data Found! Please add data to the database.")
st.subheader("Image Results")
cols = st.columns(3)
for i in range(3):
with cols[i]:
if image_index:
image_path = image_data['path'].iloc[image_indices[0][i]]
image = Image.open(image_path)
image = preprocess(image).unsqueeze(0).to(device)
text = clip.tokenize([text_input]).to(device)
image_features = clip_model.encode_image(image)
text_features = clip_model.encode_text(text)
cosine_similarity = torch.cosine_similarity(image_features, text_features)
st.write(f"Similarity: {cosine_similarity.item() * 100:.2f}%")
st.image(image_path)
st.subheader("Text Results")
cols = st.columns(3)
for i in range(3):
with cols[i]:
if text_index:
text_content = text_data['content'].iloc[text_indices[0][i]]
st.write(text_content)
st.subheader("Audio Results")
cols = st.columns(3)
for i in range(3):
with cols[i]:
if audio_index:
audio_path = audio_data['path'].iloc[audio_indices[0][i]]
audio_content = audio_data['content'].iloc[audio_indices[0][i]]
st.audio(audio_path)
st.write(f"_{audio_content}_") |