# ----------------------------------------------------------------------------
# Copyright (c) 2024 Amar Ali-bey
#
# OpenVPRLab: https://github.com/amaralibey/nanoCLIP
#
# Licensed under the MIT License. See LICENSE file in the project root.
# ----------------------------------------------------------------------------
from pathlib import Path
from typing import List, Tuple, Optional
import torch
import torch.nn.functional as F
import faiss
from transformers import AutoTokenizer
import gradio as gr
from text_encoder import TextEncoder
from load_album import AlbumDataset


class ImageSearchEngine:
def __init__(
self,
model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
output_dim: int = 64,
gallery_folder: str = "photos",
device: str = 'cpu'
):
if device == 'cuda' and not torch.cuda.is_available():
print("CUDA is not available. Using CPU instead.")
device = 'cpu'
self.device = torch.device(device)
self.setup_model(model_name, output_dim)
self.setup_gallery(gallery_folder)

    def setup_model(self, model_name: str, output_dim: int) -> None:
"""Initialize and load the text encoder model."""
self.txt_encoder = TextEncoder(
output_dim=output_dim,
lang_model=model_name
).to(self.device)
        # Load the pre-trained weights for the text encoder
        weights_path = Path(__file__).parent.resolve() / 'txt_encoder_state_dict.pth'
        if not weights_path.exists():
            raise FileNotFoundError(
                f"Text encoder weights not found at {weights_path}. "
                "Make sure to run the create_index.py script first."
            )
weights = torch.load(weights_path, map_location=self.device, weights_only=True)
self.txt_encoder.load_state_dict(weights)
self.txt_encoder.eval()
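        # The tokenizer must match the base language model the text encoder
        # was trained with, hence it is loaded from the same model name.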
self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def setup_gallery(self, gallery_folder: str) -> None:
        """Set up the image gallery and FAISS index."""
        gallery_path = Path(__file__).parent.resolve() / f'gallery/{gallery_folder}'
        if not gallery_path.exists():
            raise FileNotFoundError(f"Gallery folder not found: {gallery_path}")
        # AlbumDataset only collects the image paths; it does not load the
        # images themselves. This is more efficient because Gradio loads the
        # images lazily from the paths returned by the search method.
self.dataset = AlbumDataset(gallery_path, transform=None)
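        # A minimal sketch of what AlbumDataset is assumed to provide (the
        # actual class lives in load_album.py): an `imgs` list holding the
        # paths of the images found in the folder, e.g.
        #   self.imgs = sorted(p for p in gallery_path.rglob('*')
        #                      if p.suffix.lower() in {'.jpg', '.jpeg', '.png'})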
        # Load the FAISS index. The index file lives next to the gallery folder
        # and shares its name (e.g. gallery/photos -> gallery/photos.faiss).
        index_path = gallery_path.parent / f"{gallery_folder}.faiss"
        if not index_path.exists():
            raise FileNotFoundError(
                f"FAISS index not found at {index_path}. "
                "Make sure to run the create_index.py script first."
            )
        self.index = faiss.read_index(index_path.as_posix())
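        # Note: the index itself is assumed to be built by create_index.py.
        # A hypothetical sketch of what that could look like, assuming an
        # inner-product index over L2-normalized image embeddings:
        #   index = faiss.IndexFlatIP(output_dim)
        #   index.add(image_embeddings)  # float32 array of shape (N, output_dim)
        #   faiss.write_index(index, f"gallery/{gallery_folder}.faiss")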

    @torch.no_grad()
def encode_query(self, query_text: str) -> torch.Tensor:
"""Encode the text query into embeddings."""
inputs = self.tokenizer(query_text, truncation=True, return_tensors="pt")
inputs = inputs['input_ids'].to(self.device)
embedding = self.txt_encoder(inputs)
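        # L2-normalize so that nearest-neighbor search ranks results by cosine
        # similarity (this holds for both inner-product and L2 indexes over
        # unit-norm vectors).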
embedding = F.normalize(embedding, p=2, dim=1)
return embedding.cpu()

    def search(self, query_text: str, k: int = 20) -> List[Tuple[str, Optional[str]]]:
        """Search for images matching the query text."""
        if len(query_text.strip()) < 3:  # avoid searching for very short queries
            return []
        query_embedding = self.encode_query(query_text)
        # FAISS expects a contiguous float32 numpy array
        dist, indices = self.index.search(query_embedding.numpy(), k)
# you can filter results according to a threshold on the distance
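        # A hypothetical sketch, assuming an inner-product index where higher
        # scores are better (the cutoff value would need tuning):
        #   hits = [i for i, d in zip(indices[0], dist[0]) if d > 0.3]
        #   return [(self.dataset.imgs[i], None) for i in hits]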
return [(self.dataset.imgs[idx], None) for idx in indices[0]]


class GalleryUI:
def __init__(self, search_engine: ImageSearchEngine):
self.search_engine = search_engine
self.css_path = Path(__file__).parent / 'style.css'

    def load_css(self) -> str:
"""Load CSS styles from file."""
with open(self.css_path) as f:
return f.read()

    def create_interface(self) -> gr.Blocks:
"""Create the Gradio interface."""
with gr.Blocks(css=self.load_css(), theme=gr.themes.Soft(text_size='lg')) as demo:
with gr.Column(elem_classes="container"):
self._create_header()
self._create_search_section()
self._create_footer()
self._setup_callbacks(demo)
return demo

    def _create_header(self) -> None:
"""Create the header section."""
with gr.Column(elem_classes="header"):
gr.Markdown("# Gallery Search")
gr.Markdown("Search through your collection of photos with AI")

    def _create_search_section(self) -> None:
"""Create the search interface section."""
with gr.Column():
self.query_text = gr.Textbox(
placeholder="Example: Riding my horse",
label="Search Query",
elem_classes="search-input",
autofocus=True,
container=False,
interactive=True
)
self.gallery = gr.Gallery(
label="Search Results",
columns=[4],
height=600,
object_fit="cover",
elem_classes="gallery",
container=False,
)

    def _create_footer(self) -> None:
"""Create the footer section."""
with gr.Column(elem_classes="footer"):
gr.Markdown(
"""Created by [Amar Ali-bey](https://amaralibey.github.io) |
[View on GitHub](https://github.com/amaralibey/nanoCLIP)"""
)

    def _setup_callbacks(self, demo: gr.Blocks) -> None:
        """Set up the interface callbacks."""
        self.query_text.submit(
            self.search_engine.search,
            inputs=[self.query_text],
outputs=self.gallery,
show_progress='hidden',
)


search_engine = ImageSearchEngine(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    output_dim=64,
    gallery_folder="photos",
)
ui = GalleryUI(search_engine)
demo = ui.create_interface()

if __name__ == "__main__":
demo.launch()