File size: 503 Bytes
99fa459
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from sentence_transformers import SentenceTransformer
from sentence_transformers import util
import torch


def predict(
    query: str,
    corpus_embeddings: torch.Tensor,
    corpus_labels: list,
    model: SentenceTransformer,
    top_k: int = 5,
) -> list:
    """Return the labels of the corpus entries most similar to ``query``.

    The query is embedded with ``model`` and compared against
    ``corpus_embeddings`` using ``sentence_transformers.util.semantic_search``.

    Args:
        query: Text to look up.
        corpus_embeddings: Precomputed corpus embeddings. Row ``i`` is
            presumably the embedding of ``corpus_labels[i]``; confirm this
            with the caller.
        corpus_labels: Labels indexed by corpus position (``corpus_id``).
        model: Model used to embed the query.
        top_k: Maximum number of hits requested from ``semantic_search``.

    Returns:
        A list of labels, one per hit, in the order ``semantic_search``
        returns them. Per the library docs, that order is decreasing
        cosine-similarity score.
    """
    # Encode as a one-element batch so the result lines up with
    # semantic_search's per-query output.
    encoded_query = model.encode([query])
    hits_per_query = util.semantic_search(
        encoded_query, corpus_embeddings, top_k=top_k
    )
    # Only one query was submitted, so only the first hit list matters.
    query_hits = hits_per_query[0]
    return [corpus_labels[hit["corpus_id"]] for hit in query_hits]