embeddings / model.py
Pavithiran's picture
Update model.py
6ad40b4 verified
raw
history blame contribute delete
719 Bytes
from sentence_transformers import SentenceTransformer
import torch
class Model:
    """Embedding handler that wraps a SentenceTransformer model.

    The model is loaded once at construction time; each call encodes the
    text(s) found under the payload's ``"inputs"`` key and returns the
    embeddings in a JSON-serializable form together with their shape.
    """

    # Model loaded when no explicit name is given (the original hard-coded choice).
    DEFAULT_MODEL_NAME = "all-MiniLM-L6-v2"

    def __init__(self, model_name: str = DEFAULT_MODEL_NAME):
        """Load the pre-trained sentence-embedding model.

        Args:
            model_name: Name or path handed to ``SentenceTransformer``.
                Defaults to ``"all-MiniLM-L6-v2"``.
        """
        self.embedding_model = SentenceTransformer(model_name)

    def __call__(self, payload: dict) -> dict:
        """Encode the text chunks in ``payload["inputs"]``.

        Args:
            payload: Request mapping. ``"inputs"`` holds the text(s) to embed
                and is forwarded as-is to ``SentenceTransformer.encode``
                (presumably a string or a list of strings — confirm with
                callers). A missing key and an explicit ``None`` are both
                treated as an empty list.

        Returns:
            A dict with ``"embeddings"`` (the embedding tensor converted to
            nested Python lists) and ``"shape"`` (the tensor's shape as a
            list of ints).
        """
        chunks = payload.get("inputs")
        # dict.get's default only applies when the key is missing; normalise an
        # explicit None (e.g. a JSON null) to the same empty list so both cases
        # behave alike instead of forwarding None to encode().
        if chunks is None:
            chunks = []

        embeddings = self.embedding_model.encode(chunks, convert_to_tensor=True)

        return {
            # Convert the tensor to plain lists so the response is JSON-serializable.
            "embeddings": embeddings.tolist(),
            "shape": list(embeddings.shape),
        }