import os
import sys
from typing import List

import gradio as gr
import requests
import torch
import uvicorn
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import RedirectResponse
from pdf2image import convert_from_bytes
from PIL import Image
from torch.utils.data import DataLoader
from transformers import AutoProcessor

# Make the vendored colpali-main package importable.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "./colpali-main")))

from colpali_engine.models.paligemma_colbert_architecture import ColPali
from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
from colpali_engine.utils.colpali_processing_utils import (
    process_images,
    process_queries,
)

app = FastAPI()

# Load the PaliGemma backbone and apply the ColPali retrieval adapter.
model_name = "vidore/colpali"
token = os.environ.get("HF_TOKEN")
model = ColPali.from_pretrained(
    "google/paligemma-3b-mix-448",
    torch_dtype=torch.bfloat16,
    device_map="cpu",
    token=token,
).eval()
model.load_adapter(model_name)
processor = AutoProcessor.from_pretrained(model_name, token=token)

# Move the model to GPU when one is available.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
if str(model.device) != device:
    model.to(device)

# Placeholder image required by the query-processing helper.
mock_image = Image.new("RGB", (448, 448), (255, 255, 255))

# In-memory storage for page images and their multi-vector embeddings.
ds = []
images = []


@app.get("/")
def read_root():
    return RedirectResponse(url="/docs")


@app.post("/index")
async def index(files: List[UploadFile] = File(...)):
    """Convert uploaded PDFs to page images and embed each page with ColPali."""
    global ds, images
    images = []
    ds = []
    for file in files:
        content = await file.read()
        pdf_image_list = convert_from_bytes(content)
        images.extend(pdf_image_list)

    dataloader = DataLoader(
        images,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda x: process_images(processor, x),
    )
    for batch_doc in dataloader:
        with torch.no_grad():
            batch_doc = {k: v.to(device) for k, v in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
        ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
    return {"message": f"Uploaded and converted {len(images)} pages"}


@app.post("/search")
async def search(query: str, k: int):
    """Embed the text query and return the top-k pages by multi-vector score."""
    qs = []
    with torch.no_grad():
        batch_query = process_queries(processor, [query], mock_image)
        batch_query = {key: value.to(device) for key, value in batch_query.items()}
        embeddings_query = model(**batch_query)
    qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))

    # Late-interaction scoring between the query vectors and every indexed page.
    retriever_evaluator = CustomEvaluator(is_multi_vector=True)
    scores = retriever_evaluator.evaluate(qs, ds)
    top_k_indices = scores.argsort(axis=1)[0][-k:][::-1]

    # Cast indices to plain ints so the response is JSON-serializable.
    results = [{"page": int(idx), "image": "image_placeholder"} for idx in top_k_indices]
    return {"results": results}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
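
# --- Example client usage (a minimal sketch, not part of the application above) ---
# Assumes the server is running locally on port 7860; the file name and query
# string below are placeholders chosen for illustration.
#
#   import requests
#
#   # Index a PDF: each page is converted to an image and embedded.
#   with open("example.pdf", "rb") as f:
#       requests.post(
#           "http://localhost:7860/index",
#           files=[("files", ("example.pdf", f, "application/pdf"))],
#       )
#
#   # Retrieve the top-3 pages for a text query.
#   resp = requests.post(
#       "http://localhost:7860/search",
#       params={"query": "What is the revenue in 2023?", "k": 3},
#   )
#   print(resp.json())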