import torch
from sentence_transformers import SentenceTransformer


class Model:
    """Sentence-embedding inference handler wrapping a SentenceTransformer.

    Calling an instance with a payload mapping encodes ``payload["inputs"]``
    and returns the embeddings as nested Python lists, together with the
    shape of the embedding tensor.
    """

    def __init__(self, model_name="all-MiniLM-L6-v2"):
        """Load the pre-trained sentence-embedding model.

        Args:
            model_name: Name or path handed to ``SentenceTransformer``.
                Defaults to ``"all-MiniLM-L6-v2"``, the model this class
                previously hard-coded, so existing callers are unaffected.
        """
        self.embedding_model = SentenceTransformer(model_name)

    def __call__(self, payload):
        """Embed the text found under ``payload["inputs"]``.

        Args:
            payload: Mapping with an ``"inputs"`` entry holding the text to
                embed (presumably a string or a list of strings, i.e. whatever
                ``SentenceTransformer.encode`` accepts — confirm with callers).
                A missing key or a ``None`` value is treated as no input.

        Returns:
            dict with:
              - ``"embeddings"``: the embedding tensor converted to nested
                lists so the response is JSON-serializable;
              - ``"shape"``: the tensor's shape as a list of ints.
            For empty input, ``{"embeddings": [], "shape": [0]}``.
        """
        chunks = payload.get("inputs", [])
        if chunks is None:
            # A JSON ``null`` would otherwise be passed straight into encode().
            chunks = []

        if isinstance(chunks, (list, tuple)) and not chunks:
            # Answer empty batches directly instead of relying on the
            # library's empty-batch path, whose handling has varied between
            # sentence-transformers versions (NOTE(review): verify against
            # the pinned version). A bare "" string is still encoded as before.
            return {"embeddings": [], "shape": [0]}

        embeddings = self.embedding_model.encode(chunks, convert_to_tensor=True)
        return {
            # .tolist() yields plain Python numbers for JSON serialization.
            "embeddings": embeddings.tolist(),
            "shape": list(embeddings.shape),
        }