Spaces:
Sleeping
Sleeping
File size: 503 Bytes
99fa459 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
from sentence_transformers import SentenceTransformer
from sentence_transformers import util
import torch
def predict(
    query: str,
    corpus_embeddings: torch.Tensor,
    corpus_labels: list,
    model: SentenceTransformer,
    top_k: int = 5,
) -> list:
    """Return the labels of the corpus entries most similar to ``query``.

    The query is embedded with ``model`` and searched against
    ``corpus_embeddings`` using ``sentence_transformers.util.semantic_search``.
    Each hit's ``corpus_id`` is then used as an index into ``corpus_labels``.

    Args:
        query: Text to look up.
        corpus_embeddings: Precomputed corpus embeddings. Row ``i`` is
            assumed to correspond to ``corpus_labels[i]`` (TODO confirm at
            call sites).
        corpus_labels: Labels indexed by corpus position.
        model: Encoder used to embed ``query``. Presumably this is the same
            model that produced ``corpus_embeddings`` (verify against caller).
        top_k: Maximum number of hits to retrieve.

    Returns:
        The labels of the retrieved hits, in the order ``semantic_search``
        returns them (documented by the library as highest score first).
    """
    # Encode a one-element batch. semantic_search then gets a 2-D input and
    # returns one hit list per query.
    batch = [query]
    encoded = model.encode(batch)
    hits_per_query = util.semantic_search(encoded, corpus_embeddings, top_k=top_k)

    # Only one query was submitted, so only the first hit list matters.
    hits = hits_per_query[0]
    labels: list = []
    for hit in hits:
        labels.append(corpus_labels[hit["corpus_id"]])
    return labels
|