# ----------------------------------------------------------------------------
# Copyright (c) 2024 Amar Ali-bey
#
# nanoCLIP: https://github.com/amaralibey/nanoCLIP
#
# Licensed under the MIT License. See LICENSE file in the project root.
# ----------------------------------------------------------------------------
from pathlib import Path
from typing import List, Tuple, Optional

import torch
import torch.nn.functional as F
import faiss
from transformers import AutoTokenizer
import gradio as gr

from text_encoder import TextEncoder
from load_album import AlbumDataset


class ImageSearchEngine:
    def __init__(
        self,
        model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
        output_dim: int = 64,
        gallery_folder: str = "photos",
        device: str = 'cpu',
    ):
        if device == 'cuda' and not torch.cuda.is_available():
            print("CUDA is not available. Using CPU instead.")
            device = 'cpu'
        self.device = torch.device(device)
        self.setup_model(model_name, output_dim)
        self.setup_gallery(gallery_folder)
    def setup_model(self, model_name: str, output_dim: int) -> None:
        """Initialize and load the text encoder model."""
        self.txt_encoder = TextEncoder(
            output_dim=output_dim,
            lang_model=model_name,
        ).to(self.device)

        # Load the pre-trained weights for the text encoder
        weights_path = Path(__file__).parent.resolve() / 'txt_encoder_state_dict.pth'
        if not weights_path.exists():
            raise FileNotFoundError(
                f"Text encoder weights not found: {weights_path}. "
                "Make sure to run the create_index.py script first."
            )
        weights = torch.load(weights_path, map_location=self.device, weights_only=True)
        self.txt_encoder.load_state_dict(weights)
        self.txt_encoder.eval()

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
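    # For reference, a minimal sketch of how the weights file above could be
    # produced (an assumption about what create_index.py does; the actual
    # script may differ):
    #
    #   torch.save(txt_encoder.state_dict(), 'txt_encoder_state_dict.pth')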
    def setup_gallery(self, gallery_folder: str) -> None:
        """Set up the image gallery and FAISS index."""
        gallery_path = Path(__file__).parent.resolve() / f'gallery/{gallery_folder}'
        if not gallery_path.exists():
            raise FileNotFoundError(f"Album folder {gallery_path} not found")

        # We use the AlbumDataset class to load the image paths only (not the
        # images themselves). This is more efficient than loading the images
        # directly, because Gradio loads them lazily from the paths returned
        # by the search method.
        self.dataset = AlbumDataset(gallery_path, transform=None)

        # Load the FAISS index. The index file should be in the same folder
        # as the gallery and have the same name as the folder being indexed.
        index_path = gallery_path.parent / f"{gallery_folder}.faiss"
        self.index = faiss.read_index(index_path.as_posix())
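    # For reference, a minimal sketch of how such an index could be built
    # (an assumption about what create_index.py does; it presumes L2-normalized
    # image embeddings as a float32 numpy array of shape (N, output_dim)):
    #
    #   index = faiss.IndexFlatIP(output_dim)
    #   index.add(img_embeddings)
    #   faiss.write_index(index, f"{gallery_folder}.faiss")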
    def encode_query(self, query_text: str) -> torch.Tensor:
        """Encode the text query into a normalized embedding."""
        inputs = self.tokenizer(query_text, truncation=True, return_tensors="pt")
        inputs = inputs['input_ids'].to(self.device)
        with torch.no_grad():  # inference only, no need to track gradients
            embedding = self.txt_encoder(inputs)
        # L2-normalize so inner-product search is equivalent to cosine similarity
        embedding = F.normalize(embedding, p=2, dim=1)
        return embedding.cpu()
    def search(self, query_text: str, k: int = 20) -> List[Tuple[str, Optional[str]]]:
        """Search for images matching the query text."""
        if len(query_text) < 3:  # avoid searching for very short queries
            return []
        query_embedding = self.encode_query(query_text)
        # FAISS expects a float32 numpy array of shape (n_queries, dim)
        dist, indices = self.index.search(query_embedding.numpy(), k)
        # You can filter the results according to a threshold on the distance.
        return [(self.dataset.imgs[idx], None) for idx in indices[0]]
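    # A minimal sketch of the threshold filter mentioned above (the 0.5
    # cutoff is an arbitrary assumption; with an inner-product index over
    # normalized embeddings, higher scores mean better matches, whereas for
    # an L2 index the comparison would be reversed):
    #
    #   return [(self.dataset.imgs[i], None)
    #           for d, i in zip(dist[0], indices[0]) if d > 0.5]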


class GalleryUI:
    def __init__(self, search_engine: ImageSearchEngine):
        self.search_engine = search_engine
        self.css_path = Path(__file__).parent / 'style.css'

    def load_css(self) -> str:
        """Load CSS styles from file."""
        with open(self.css_path) as f:
            return f.read()

    def create_interface(self) -> gr.Blocks:
        """Create the Gradio interface."""
        with gr.Blocks(css=self.load_css(), theme=gr.themes.Soft(text_size='lg')) as demo:
            with gr.Column(elem_classes="container"):
                self._create_header()
                self._create_search_section()
                self._create_footer()
            self._setup_callbacks(demo)
        return demo
    def _create_header(self) -> None:
        """Create the header section."""
        with gr.Column(elem_classes="header"):
            gr.Markdown("# Gallery Search")
            gr.Markdown("Search through your collection of photos with AI")

    def _create_search_section(self) -> None:
        """Create the search interface section."""
        with gr.Column():
            self.query_text = gr.Textbox(
                placeholder="Example: Riding my horse",
                label="Search Query",
                elem_classes="search-input",
                autofocus=True,
                container=False,
                interactive=True,
            )
            self.gallery = gr.Gallery(
                label="Search Results",
                columns=[4],
                height=600,
                object_fit="cover",
                elem_classes="gallery",
                container=False,
            )
    def _create_footer(self) -> None:
        """Create the footer section."""
        with gr.Column(elem_classes="footer"):
            gr.Markdown(
                """Created by [Amar Ali-bey](https://amaralibey.github.io) |
                [View on GitHub](https://github.com/amaralibey/nanoCLIP)"""
            )

    def _setup_callbacks(self, demo: gr.Blocks) -> None:
        """Set up the interface callbacks."""
        self.query_text.submit(
            self.search_engine.search,
            inputs=[self.query_text],
            outputs=self.gallery,
            show_progress='hidden',
        )


search_engine = ImageSearchEngine(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    output_dim=64,
    gallery_folder="photos",
)

ui = GalleryUI(search_engine)
demo = ui.create_interface()

if __name__ == "__main__":
    demo.launch()
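# Note: `demo` is kept at module level so that hosting environments such as
# Hugging Face Spaces can pick it up automatically. When running locally, you
# can also pass share=True to demo.launch() to get a temporary public URL.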