InferencetrainingAI
committed on
Commit
•
67d226a
1
Parent(s):
148ae38
SER (Small Efficient Retrieval) init commit
Browse files
ser.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class SearchMood:
    """Small embedding-based retrieval helper ("SER").

    Encodes mood prompts/states with a SentenceTransformer model and uses
    cosine similarity to (a) find, for each stored context sample, the index
    of the best-matching database entry, and (b) score randomly probed
    database entries against the instance's mood prompt as a loss.

    NOTE(review): this file defines no imports — `SentenceTransformer`,
    `util` (both from sentence_transformers), `shelve`, and `torch` must be
    imported at file level for this class to run.
    """

    def __init__(self, mood_prompt, prior_init):
        """Load the embedding model and the shelve-backed data stores.

        mood_prompt -- reference prompt scored by loss_fn()
        prior_init  -- stored verbatim; not otherwise used in this class
        """
        self.prior_init = prior_init
        self.mood_prompt = mood_prompt
        # Sentence-embedding model used for every similarity computation.
        self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
        # Encode a (prompt, state) pair into a tuple of embedding tensors.
        # NOTE: the misspelled names `embeding`/`embedings` are part of the
        # public surface and are kept for backward compatibility.
        self.embeding = lambda mood_prompt, mood_state: (
            self.model.encode(mood_prompt, convert_to_tensor=True),
            self.model.encode(mood_state, convert_to_tensor=True),
        )
        # Cosine similarity between two embedding tensors.
        self.similar = lambda similarx, similary: util.pytorch_cos_sim(similarx, similary)
        # Read the sample once and close the shelf immediately — the original
        # code left this handle open (resource leak).
        with shelve.open('cx_sample.db') as cx_db:
            self.cx_sample = cx_db['sample']
        # Kept open for the object's lifetime; entries are read lazily.
        # NOTE(review): never closed — consider adding a close() method.
        self.database = shelve.open('database.db')
        # Standard-normal prior parameters (mean, std); stored on the class,
        # as in the original implementation.
        SearchMood.prior_component = torch.tensor([0., 1.])
        # Five N(0, 1) draws; loss_fn() rounds these into database indices.
        self.prior_sample = torch.normal(
            self.prior_component[0], self.prior_component[1], size=(5,)
        )
        self.sample_losses = None

    def embedings(self, samplex, sampley):
        """Return the cosine similarity between the embeddings of two texts."""
        emb_x, emb_y = self.embeding(samplex, sampley)
        return self.similar(emb_x, emb_y)

    def mood_dist(self, data_sample=False, mood_prompt=False, search=True):
        """Score moods against the database.

        With search=True (default): for every stored context sample, find the
        index of the most similar entry in database['database'] and return
        those indices as a 1-D float tensor (data_sample/mood_prompt unused).

        With search=False: return the similarity of mood_prompt vs
        data_sample wrapped in a tensor.
        """
        cx_index = []
        if search:
            for mood_state in self.cx_sample:
                best_index = 0
                # NOTE(review): cosine similarity can be negative; a row with
                # all-negative scores keeps index 0 — confirm that's intended.
                best_score = 0
                for index, mood_prompts in enumerate(self.database['database']):
                    simemb = self.embedings(mood_state, mood_prompts)
                    if best_score < simemb:
                        best_score = simemb
                        best_index = index
                cx_index.append(float(best_index))
        else:
            cx_index.append(self.embedings(mood_prompt, data_sample))
        return torch.tensor(cx_index)

    def loss_fn(self):
        """Probe up to five randomly chosen database entries against the mood
        prompt; return the negated similarity of the last probed entry as a
        single-element tensor. Stops early on a (near-)perfect match.
        """
        for sample in self.prior_sample:
            # Round the N(0, 1) draw onto a database index.
            # NOTE(review): draws can be negative, which indexes from the end
            # of the sequence — confirm this wrap-around is intended.
            data_sample = self.database['database'][round(sample.item())]
            samp_loss = self.mood_dist(data_sample, self.mood_prompt, search=False)
            print(samp_loss)
            if samp_loss.item() >= 1.:
                print('test')
                break
        # Negate: higher similarity -> lower (more negative) loss.
        return torch.tensor([samp_loss * -1])
|
65 |
+
|
66 |
+
|
67 |
+
|