# Description: Using Facebook's Faiss library to perform semantic search according to the query
# Reference: https://deepnote.com/blog/semantic-search-using-faiss-and-mpnet

from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
import faiss


class SemanticEmbedding:
    """Turn sentences into L2-normalised, mean-pooled transformer embeddings."""

    def __init__(self, model_name="sentence-transformers/all-mpnet-base-v2"):
        """Load the tokenizer and model identified by ``model_name``."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    def mean_pooling(self, model_output, attention_mask):
        """Average the token embeddings, ignoring masked-out (padding) tokens.

        Args:
            model_output: Model output whose first element holds all token
                embeddings.
            attention_mask: Tokenizer attention mask; positions with value 0
                contribute nothing to the average.

        Returns:
            A tensor with one pooled embedding per input sequence.
        """
        # First element of model_output contains all token embeddings.
        token_embeddings = model_output[0]
        # Broadcast the mask over the embedding dimension so it can weight
        # every component of every token vector.
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        summed = torch.sum(token_embeddings * input_mask_expanded, 1)
        # Clamp so a fully masked sequence cannot cause a division by zero.
        token_counts = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return summed / token_counts

    def get_embedding(self, sentences):
        """Embed ``sentences`` and return the result as a NumPy array.

        Args:
            sentences: Text handed straight to the tokenizer (padding and
                truncation enabled).

        Returns:
            A 2-D NumPy array with one L2-normalised embedding per row.
        """
        # Tokenize sentences
        encoded_input = self.tokenizer(
            sentences, padding=True, truncation=True, return_tensors="pt"
        )

        # Inference only: no gradients are needed.
        with torch.no_grad():
            model_output = self.model(**encoded_input)

        # Perform pooling
        sentence_embeddings = self.mean_pooling(
            model_output, encoded_input["attention_mask"]
        )

        # Normalize embeddings so that an inner product equals cosine similarity.
        sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)

        return sentence_embeddings.detach().numpy()


class FaissForQuerySearch:
    """Inner-product Faiss index over document embeddings with label lookup.

    Row ``i`` of the Faiss index corresponds to ``uuid[i]``, ``labels[i]`` and
    ``doc_map[i]``; ``add_summary`` keeps these four structures aligned.
    """

    def __init__(self, model, dim=768):
        """Create an empty index.

        Args:
            model: Object exposing ``get_embedding(text)`` that returns a 2-D
                array of embeddings with ``dim`` columns.
            dim: Dimensionality of the embeddings stored in the index.
        """
        self.index = faiss.IndexFlatIP(dim)
        # Maintaining the document data
        self.doc_map = dict()
        self.model = model
        self.ctr = 0
        self.uuid = []
        self.labels = []

    def search_query(self, query, k=1):
        """Return the predicted labels of the ``k`` documents closest to ``query``.

        Only the neighbours of the first embedded query row are used.

        Args:
            query: Query text to embed and search for.
            k: Maximum number of neighbours to retrieve.

        Returns:
            A list with at most ``k`` labels, ordered as returned by the index.
        """
        _, neighbours = self.index.search(self.model.get_embedding(query), k)
        # Positions the index cannot fill (fewer than k stored vectors) are
        # not keys of doc_map and are therefore dropped here.
        return [self.labels[idx] for idx in neighbours[0] if idx in self.doc_map]

    def add_summary(self, document_text, id, predicted_label):
        """Index one document together with its id and predicted label.

        Args:
            document_text: Text of the document; must embed to a single vector.
            id: Identifier (uuid) of the document. The name shadows the
                builtin but is kept for interface compatibility.
            predicted_label: Label returned by ``search_query`` when this
                document matches.

        Raises:
            ValueError: If ``document_text`` embeds to anything other than
                exactly one vector. Adding several vectors while recording a
                single id/label would shift every later index position
                against ``labels``/``uuid``/``doc_map``.
        """
        embedding = self.model.get_embedding(document_text)
        if embedding.shape[0] != 1:
            raise ValueError(
                "add_summary expects a single document, but the input produced "
                f"{embedding.shape[0]} embeddings"
            )

        self.index.add(embedding)  # index
        self.uuid.append(id)  # appending the uuid
        self.labels.append(predicted_label)  # appending the predicted label
        self.doc_map[self.ctr] = document_text  # store the original document text
        self.ctr += 1