InferencetrainingAI committed on
Commit
67d226a
1 Parent(s): 148ae38

SER (Small Efficient Retrieval) init commit

Files changed (1)
  1. ser.py +67 -0
ser.py ADDED
@@ -0,0 +1,67 @@
+ import shelve
+
+ import torch
+ from sentence_transformers import SentenceTransformer, util
+
+
+ class SearchMood:
+     """Small Efficient Retrieval: scores stored prompts against a mood prompt."""
+
+     def __init__(self, mood_prompt, prior_init):
+         self.prior_init = prior_init
+         self.mood_prompt = mood_prompt
+         # Sentence embedder and cosine-similarity helper.
+         self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
+         self.embeding = lambda mood_prompt, mood_state: (
+             self.model.encode(mood_prompt, convert_to_tensor=True),
+             self.model.encode(mood_state, convert_to_tensor=True),
+         )
+         self.similar = lambda similarx, similary: util.pytorch_cos_sim(similarx, similary)
+         # Persistent stores: a list of mood states and a list of candidate prompts.
+         self.cx_sample = shelve.open('cx_sample.db')['sample']
+         self.database = shelve.open('database.db')
+         # Prior over database indices: five draws from N(0, 1).
+         self.prior_component = torch.tensor([0., 1.])
+         self.prior_sample = torch.normal(self.prior_component[0], self.prior_component[1], size=(5,))
+         self.sample_losses = None
+
+     def embedings(self, samplex, sampley):
+         # Embed both texts and return their cosine similarity.
+         emb = self.embeding(samplex, sampley)
+         similarity = self.similar(emb[0], emb[1])
+         return similarity
+
+     def mood_dist(self, data_sample=False, mood_prompt=False, search=True):
+         cx_index = []
+         if search:
+             # For every stored mood state, record the index of the most
+             # similar database prompt.
+             for mood_state in self.cx_sample:
+                 max_sample = 0
+                 index_sample = 0
+                 for index, mood_prompts in enumerate(self.database['database']):
+                     simemb = self.embedings(mood_state, mood_prompts).item()
+                     if max_sample < simemb:
+                         max_sample = simemb
+                         index_sample = index
+                 cx_index.append(float(index_sample))
+         else:
+             # Score a single (mood_prompt, data_sample) pair.
+             cx_index.append(self.embedings(mood_prompt, data_sample).item())
+
+         return torch.tensor(cx_index)
+
+     def loss_fn(self):
+         # Draw database entries via the prior and score them against the
+         # mood prompt; stop early on a perfect match.
+         for sample in self.prior_sample:
+             sample = sample.item()
+             data_sample = self.database['database'][round(sample)]
+             samp_loss = self.mood_dist(data_sample, self.mood_prompt, search=False)
+             print(samp_loss)
+             if samp_loss.item() >= 1.:
+                 print('test')
+                 break
+
+         # Negate the similarity so it can be minimised as a loss.
+         return torch.tensor([samp_loss.item() * -1.])
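
A minimal usage sketch (not part of this commit), assuming ser.py exposes SearchMood as above and that the two shelve stores hold a list of mood states under the key 'sample' and a list of candidate prompts under 'database'; the store contents, the mood_prompt value, and the variable names below are illustrative only.

import shelve

from ser import SearchMood

# Seed the shelve stores the class expects (hypothetical contents).
with shelve.open('cx_sample.db') as cx:
    cx['sample'] = ['calm', 'anxious', 'excited']

with shelve.open('database.db') as db:
    # loss_fn rounds draws from N(0, 1) into indices, so keep several entries.
    db['database'] = ['a quiet walk', 'a looming deadline', 'a surprise party',
                      'a rainy afternoon', 'a long run', 'a family dinner',
                      'a crowded concert']

# prior_init is stored but not used by the current code.
mood = SearchMood(mood_prompt='a quiet walk', prior_init=None)

print(mood.mood_dist())   # best-matching database index per stored mood state
print(mood.loss_fn())     # negated similarity of a sampled database entry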