import gradio as gr
import torch
from datasets import load_dataset
from qdrant_client import QdrantClient
from qdrant_client.http import models
from colpali_engine.models import ColQwen2, ColQwen2Processor
from PIL import Image
import requests
from io import BytesIO

# Initialize the model, processor, and Qdrant client
model_name = "vidore/colqwen2-v0.1"
colpali_model = ColQwen2.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map="cuda:0"
)
colpali_processor = ColQwen2Processor.from_pretrained(model_name)
qdrant_client = QdrantClient(":memory:")
collection_name = "image_collection"

# Load the dataset (this should be done only once when setting up the app)
dataset = load_dataset("davanstrien/loc-nineteenth-century-song-sheets", split="train")


def setup_qdrant():
    # Create a collection configured for multivector (late-interaction) search
    qdrant_client.recreate_collection(
        collection_name=collection_name,
        vectors_config=models.VectorParams(
            size=128,  # ColQwen2 produces 128-dim per-token embeddings
            distance=models.Distance.COSINE,
            multivector_config=models.MultiVectorConfig(
                comparator=models.MultiVectorComparator.MAX_SIM
            ),
        ),
    )

    # Index the dataset (this should be done only once when setting up the app)
    batch_size = 32
    for i in range(0, len(dataset), batch_size):
        batch = dataset[i : i + batch_size]
        images = batch["image"]

        with torch.no_grad():
            batch_images = colpali_processor.process_images(images).to(colpali_model.device)
            image_embeddings = colpali_model(**batch_images)

        points = []
        for j, embedding in enumerate(image_embeddings):
            multivector = embedding.cpu().float().numpy().tolist()
            points.append(
                models.PointStruct(
                    id=i + j,
                    vector=multivector,
                    payload={
                        "item_id": batch["item_id"][j],
                        "item_url": batch["item_url"][j],
                    },
                )
            )

        qdrant_client.upsert(collection_name=collection_name, points=points)

    print("Indexing complete!")


def search_similar_images(query, top_k=5, mode="text"):
    with torch.no_grad():
        if mode == "text":
            batch_query = colpali_processor.process_queries([query]).to(colpali_model.device)
        else:  # Image mode
            batch_query = colpali_processor.process_images([query]).to(colpali_model.device)
        query_embedding = colpali_model(**batch_query)

    multivector_query = query_embedding[0].cpu().float().numpy().tolist()

    # Multivector queries go through the Query API
    search_result = qdrant_client.query_points(
        collection_name=collection_name,
        query=multivector_query,
        limit=top_k,
    )
    return search_result.points


def process_results(results):
    # Return a gallery list of (image_url, caption) pairs plus a parallel list of scores
    gallery = []
    scores = []
    for result in results:
        item_url = result.payload["item_url"]
        score = result.score
        gallery.append((item_url, f"Score: {score:.4f}"))
        scores.append({"item_url": item_url, "score": score})
    return gallery, scores


def text_search(query, top_k):
    results = search_similar_images(query, top_k, mode="text")
    return process_results(results)


def image_search(image, top_k):
    results = search_similar_images(image, top_k, mode="image")
    return process_results(results)


# Set up the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Image Search App")
    gr.Markdown("Search for similar images using text or image input.")

    with gr.Tab("Text Search"):
        text_input = gr.Textbox(label="Enter your search query")
        text_button = gr.Button("Search")
        text_output = gr.Gallery(
            label="Results",
            show_label=False,
            elem_id="gallery",
            columns=2,
            rows=2,
            object_fit="contain",
            height="auto",
        )
        text_scores = gr.JSON(label="Scores")

    with gr.Tab("Image Search"):
        image_input = gr.Image(type="pil", label="Upload an image")
        image_button = gr.Button("Search")
        image_output = gr.Gallery(
            label="Results",
            show_label=False,
            elem_id="gallery",
            columns=2,
            rows=2,
            object_fit="contain",
            height="auto",
        )
        image_scores = gr.JSON(label="Scores")
    top_k_slider = gr.Slider(minimum=1, maximum=20, value=5, step=1, label="Number of results")

    text_button.click(
        text_search, inputs=[text_input, top_k_slider], outputs=[text_output, text_scores]
    )
    image_button.click(
        image_search, inputs=[image_input, top_k_slider], outputs=[image_output, image_scores]
    )

# Run the setup (this should be done only once when deploying the app)
setup_qdrant()

# Launch the app
demo.launch()