import os
import json

import numpy as np
import faiss
import torch
import torch.nn.functional as F
# Note: recent PyTorch deprecates torch.cuda.amp.autocast in favor of
# torch.amp.autocast("cuda"); the old import still works.
from torch.cuda.amp import autocast
from transformers import AutoTokenizer, AutoModel
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
class FaissSearch:
    """Embeds a query with a transformer encoder and searches a FAISS index."""

    def __init__(self, model_path, index_path, index_keys_path, filtered_db_path, device='cuda:0'):
        self.device = device
        self.model_path = model_path
        self.index = faiss.read_index(index_path)
        self.max_len = 512
        # Maps FAISS row ids (as strings) to document keys.
        with open(index_keys_path, 'r', encoding='utf-8') as f:
            self.index_keys = json.load(f)
        # Maps document keys to the documents themselves.
        with open(filtered_db_path, 'r', encoding='utf-8') as f:
            self.filtered_db_data = json.load(f)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        # The encoder is loaded lazily on the first search (see _load_model).
        self.model = None

    def _load_model(self):
        if self.model is None:
            self.model = AutoModel.from_pretrained(self.model_path).to(self.device)
    def _query_tokenization(self, text):
        # text = "query: " + text  # prepend when using an e5-style model
        tokens = self.tokenizer(
            text,
            return_tensors="pt",
            padding='max_length',
            truncation=True,
            max_length=self.max_len
        )
        return tokens
    def _query_embed_extraction(self, tokens, do_normalization=True):
        self._load_model()
        self.model.eval()
        with torch.no_grad():
            with autocast():
                inputs = {k: v.to(self.device) for k, v in tokens.items()}
                outputs = self.model(**inputs)
                # Use the [CLS] token embedding as the query representation.
                embedding = outputs.last_hidden_state[:, 0].cpu()
                if do_normalization:
                    embedding = F.normalize(embedding, dim=-1)
        # Cast to float32: under autocast the embedding may be fp16, while FAISS expects float32.
        return embedding.float().numpy()
    def _search_results_filtering(self, preds, dists):
        # Sort hits by score, highest first, assuming a similarity metric
        # (e.g., inner product) where larger means more relevant.
        sorted_values = sorted(zip(preds, dists), key=lambda x: x[1], reverse=True)
        sorted_preds = [x[0] for x in sorted_values]
        sorted_scores = [x[1] for x in sorted_values]
        return sorted_preds, sorted_scores
    def search(self, query, top=20):
        query_tokens = self._query_tokenization(query)
        query_embeds = self._query_embed_extraction(query_tokens, do_normalization=True)
        # Retrieve scores for every entry in the database, then sort and truncate.
        distances, indices = self.index.search(query_embeds, len(self.filtered_db_data))
        preds = [self.index_keys[str(x)] for x in indices[0]]
        preds, scores = self._search_results_filtering(preds, distances[0])
        docs = [self.filtered_db_data[ref] for ref in preds]
        torch.cuda.empty_cache()
        return preds[:top], docs[:top]
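
# Illustrative sketch of how the artifacts consumed above could be produced;
# it is not called by this service. `build_index`, the output file naming, and
# the choice of IndexFlatIP (inner product, i.e. cosine similarity on
# normalized vectors) are assumptions, not taken from the original repository.
def build_index(doc_embeddings, keys, out_prefix):
    embs = np.ascontiguousarray(doc_embeddings, dtype=np.float32)
    faiss.normalize_L2(embs)  # in-place L2 normalization: inner product == cosine
    index = faiss.IndexFlatIP(embs.shape[1])  # exact inner-product index
    index.add(embs)
    faiss.write_index(index, f"{out_prefix}.index")
    # Row-id -> key map, matching the str(x) lookup in FaissSearch.search.
    with open(f"{out_prefix}_keys.json", 'w', encoding='utf-8') as f:
        json.dump({str(i): key for i, key in enumerate(keys)}, f, ensure_ascii=False)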
STEP = 5000  # presumably the checkpoint step encoded in the prebuilt index file names

model_path = os.environ.get("MODEL_PATH", "bge/")
index_path = f"faiss_indexes/faiss__bge_{STEP}.index"
index_keys_path = f"faiss_indexes/index_keys__bge_{STEP}.json"
filtered_db_path = f"faiss_indexes/filtered_db_data__bge_{STEP}.json"

searcher = FaissSearch(model_path, index_path, index_keys_path, filtered_db_path, os.environ.get("DEVICE", "cuda:0"))
app = FastAPI()

class SearchRequest(BaseModel):
    query: str
    top: int = 10

class SearchResponse(BaseModel):
    predictions: list
    documents: list
@app.post("/search", response_model=SearchResponse)  # route path is an assumption
async def search_endpoint(request: SearchRequest):
    try:
        preds, docs = searcher.search(request.query, top=request.top)
        return SearchResponse(predictions=preds, documents=docs)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
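
# Illustrative entry point; the host and port below are assumptions, not taken
# from the original deployment. Serve with `python app.py` or via
# `uvicorn app:app`, then query the assumed /search route, e.g.:
#   curl -X POST http://localhost:8000/search \
#        -H 'Content-Type: application/json' \
#        -d '{"query": "example question", "top": 5}'
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)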