from typing import List, Sequence, Tuple

import numpy as np
from sentence_transformers import SentenceTransformer

from app.vectorizer import Vectorizer
from app.scorer import cosine_similarity


class PromptSearchEngine:
    def __init__(self, prompts: Sequence[str]) -> None:
        """Initialize the search engine by vectorizing the prompt corpus.

        The vectorized corpus is used to find the top-n prompts most
        similar to the user's input prompt.

        Args:
            prompts: The sequence of raw prompts from the dataset.
        """
        self.prompts = prompts
        model = SentenceTransformer("all-MiniLM-L6-v2")
        self.vectorizer = Vectorizer(model)
        self.corpus_vectors = self.vectorizer.transform(prompts)

    def most_similar(self, query: str, n: int = 5) -> List[Tuple[float, str]]:
        """Return the top-n most similar prompts from the corpus.

        The query prompt is vectorized with the chosen Vectorizer, and
        cosine_similarity is used to score it against the corpus vectors.

        Args:
            query: The raw query prompt input from the user.
            n: The number of similar prompts returned from the corpus.

        Returns:
            The list of top-n most similar prompts from the corpus along
            with their similarity scores. Note that returned prompts are
            verbatim.
        """
        query_vector = self.vectorizer.transform([query])
        similarities = cosine_similarity(query_vector, self.corpus_vectors)
        # Indices of the n highest similarity scores, in descending order.
        top_n_indices = np.argsort(similarities)[-n:][::-1]
        # Convert similarities to Python float and return the top-n prompts.
        return [(float(similarities[i]), self.prompts[i]) for i in top_n_indices]
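

if __name__ == "__main__":
    # Minimal usage sketch, not part of the engine itself. It assumes that
    # app.vectorizer.Vectorizer.transform returns one embedding row per input
    # text and that app.scorer.cosine_similarity returns one score per corpus
    # prompt; both modules live outside this file. The corpus below is
    # illustrative only.
    corpus = [
        "A watercolor painting of a fox in a snowy forest",
        "Cyberpunk city street at night, neon reflections in the rain",
        "Portrait photo of an astronaut, studio lighting",
    ]
    engine = PromptSearchEngine(corpus)
    for score, prompt in engine.most_similar("neon city in the rain", n=2):
        print(f"{score:.3f}  {prompt}")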