File size: 706 Bytes
adad4ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from typing import List, Tuple, Sequence
import numpy as np
from vectorizer import Vectorizer
from similarity import cosine_similarity

class PromptSearchEngine:
    """Rank a fixed corpus of prompts by cosine similarity to a query.

    The corpus is vectorized once, at construction time, with the supplied
    vectorizer. Queries are vectorized with the same vectorizer at search time.
    """

    def __init__(self, prompts: Sequence[str], vectorizer: Vectorizer) -> None:
        """Store the corpus and precompute its vectors.

        Args:
            prompts: Corpus of prompts to search. Indexed positionally, so it
                must stay aligned with the vectors computed here.
            vectorizer: Object whose ``transform`` maps a list of strings to
                one vector per string.
        """
        self.prompts = prompts
        self.vectorizer = vectorizer
        # Computed once so each query only pays for vectorizing itself.
        self.corpus_vectors = vectorizer.transform(prompts)

    def most_similar(self, query: str, n: int = 5) -> List[Tuple[float, str]]:
        """Return the ``n`` prompts most similar to ``query``, best first.

        Args:
            query: Text to compare against the corpus.
            n: Maximum number of results. Values <= 0 yield an empty list.
                Values larger than the corpus yield the whole corpus.

        Returns:
            ``(score, prompt)`` pairs sorted by descending similarity score.
        """
        # Guard explicitly: ``argsort()[-0:]`` is ``[0:]``, which would return
        # the entire corpus for n == 0, and a negative n would return an
        # arbitrary tail of the ranking.
        if n <= 0:
            return []
        query_vector = self.vectorizer.transform([query])[0]
        # NOTE(review): assumes cosine_similarity returns a 1-D array with one
        # score per corpus prompt, in corpus order — confirm against helper.
        similarities = cosine_similarity(query_vector, self.corpus_vectors)
        # argsort is ascending; take the last n indices and reverse them so
        # the highest score comes first.
        top_indices = similarities.argsort()[-n:][::-1]
        return [(float(similarities[i]), self.prompts[i]) for i in top_indices]