prompt-search-engine / prompt_search_engine.py
Anja97's picture
Initial commit with cleaned project files
adad4ac
raw
history blame contribute delete
706 Bytes
from typing import List, Tuple, Sequence
import numpy as np
from vectorizer import Vectorizer
from similarity import cosine_similarity
class PromptSearchEngine:
    """Rank a fixed corpus of prompts by cosine similarity to a query.

    The corpus is vectorized once, at construction time, and cached in
    ``corpus_vectors``; each query is vectorized on demand and scored against
    that cache with ``cosine_similarity``.
    """

    def __init__(self, prompts: Sequence[str], vectorizer: Vectorizer) -> None:
        """Store the corpus and precompute its vectors.

        Args:
            prompts: The prompts to search over. Result indices are mapped back
                into this sequence, so it is assumed that row ``i`` of
                ``vectorizer.transform(prompts)`` corresponds to ``prompts[i]``
                — TODO confirm against ``Vectorizer``.
            vectorizer: Object whose ``transform`` method turns a sequence of
                strings into vectors; used for both the corpus and queries.
        """
        self.prompts = prompts
        self.vectorizer = vectorizer
        self.corpus_vectors = vectorizer.transform(prompts)

    def most_similar(self, query: str, n: int = 5) -> List[Tuple[float, str]]:
        """Return up to *n* prompts most similar to *query*.

        Args:
            query: Text to search for.
            n: Maximum number of results. ``n <= 0`` yields an empty list;
                ``n`` larger than the corpus returns the whole corpus.

        Returns:
            ``(similarity, prompt)`` pairs ordered from most to least similar.
            Prompts with equal scores keep their corpus order.
        """
        if n <= 0:
            # Guard needed: ``argsort()[-0:]`` is the *whole* array, so n=0
            # would otherwise return every prompt instead of none.
            return []

        query_vector = self.vectorizer.transform([query])[0]
        # asarray lets this work whether cosine_similarity returns an ndarray
        # or another array-like; assumed to be 1-D with one score per prompt.
        similarities = np.asarray(
            cosine_similarity(query_vector, self.corpus_vectors)
        )

        # Sorting the negated scores gives descending order directly. A stable
        # sort keeps ties in corpus order, and any NaN scores (e.g. from a
        # zero-norm vector, if cosine_similarity yields them) sort last rather
        # than first, which is what reversing an ascending argsort did.
        top_indices = np.argsort(-similarities, kind="stable")[:n]
        return [(float(similarities[i]), self.prompts[i]) for i in top_indices]