Pavithiran commited on
Commit
ab631a4
·
verified ·
1 Parent(s): 0b92f74

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +25 -0
model.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sentence_transformers import SentenceTransformer
2
+ import torch
3
+
4
+ class Model:
5
+ def __init__(self):
6
+ # Load the pre-trained model
7
+ self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
8
+
9
+ def __call__(self, payload):
10
+ # Extract inputs from the payload
11
+ inputs = payload.get("inputs", {})
12
+ source_sentence = inputs.get("source_sentence", "")
13
+ sentences = inputs.get("sentences", [])
14
+
15
+ # Combine source_sentence with sentences
16
+ chunks = [source_sentence] + sentences
17
+ # Generate embeddings
18
+ embeddings = self.embedding_model.encode(chunks, convert_to_tensor=True)
19
+
20
+ # Prepare response
21
+ response = {
22
+ "embeddings": embeddings.tolist(), # Convert tensor to list for JSON serialization
23
+ "shape": list(embeddings.shape) # Return the shape of the embeddings tensor
24
+ }
25
+ return response