import shelve

import torch
import torch.optim as optim
from scipy.cluster.hierarchy import fcluster, linkage
from sentence_transformers import SentenceTransformer, util

class SearchMood:
    def __init__(self, mood_prompt, prior_init):
        self.prior_init = prior_init
        self.mood_prompt = mood_prompt
        self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
        # Encode a (prompt, state) pair into a pair of embedding tensors.
        self.embedding = lambda prompt, state: (
            self.model.encode(prompt, convert_to_tensor=True),
            self.model.encode(state, convert_to_tensor=True),
        )
        # Cosine similarity between two embeddings (returns a (1, 1) tensor).
        self.similar = lambda x, y: util.pytorch_cos_sim(x, y)
        with shelve.open('cx_sample.db') as db:
            self.cx_sample = db['sample']
        # Left open: entries are looked up repeatedly during the search.
        self.database = shelve.open('database.db')
        # Mean and std of the sampling prior (instance state, not a class
        # attribute as in the original).
        self.prior_component = torch.tensor([0., 1.])
        # torch.normal with size= expects float mean/std, hence .item().
        self.prior_sample = torch.normal(
            self.prior_component[0].item(), self.prior_component[1].item(), size=(5,))
        self.sample_losses = None
    def hierarchical_clusters(self, data):
        # Hierarchical clustering of a 2-D batch of embeddings, retained as a
        # standalone diagnostic. The original called an undefined
        # `hierarchical` helper; SciPy's hierarchy module is assumed here as a
        # stand-in, with the same cosine metric and complete linkage.
        linkage_matrix = linkage(data, method='complete', metric='cosine')
        clusters = fcluster(linkage_matrix, t=0.5, criterion='distance')
        # Print and return the flat cluster labels (the original returned None).
        print(clusters)
        return clusters
    def embeddings_similarity(self, sample_x, sample_y):
        # Encode both inputs and compare the raw embeddings. The original
        # passed each embedding through hierarchical clustering first, but
        # that step returned None (and cluster labels are not valid inputs to
        # cosine similarity), so the embeddings are compared directly.
        emb_x, emb_y = self.embedding(sample_x, sample_y)
        return self.similar(emb_x, emb_y)
    def mood_dist(self, data_sample=None, mood_prompt=None, search=True):
        cx_index = []
        if search:
            # For each stored mood state, find the database entry whose
            # embedding is most similar and record its index.
            for mood_state in self.cx_sample:
                max_sample = 0.0
                index_sample = 0
                for index, mood_prompts in enumerate(self.database['database']):
                    simemb = self.embeddings_similarity(mood_state, mood_prompts).item()
                    if simemb > max_sample:
                        max_sample = simemb
                        index_sample = index
                cx_index.append(float(index_sample))
        else:
            # Score a single (prompt, sample) pair instead of searching.
            cx_index.append(self.embeddings_similarity(mood_prompt, data_sample).item())
        return torch.tensor(cx_index)
    def loss_fn(self):
        # Score each draw from the prior against the mood prompt; the loss is
        # the negated similarity of the last scored sample, so minimizing it
        # pushes toward higher similarity.
        samp_loss = torch.tensor([0.])
        for sample in self.prior_sample:
            # Clamp the rounded draw into the database's index range.
            idx = min(max(round(sample.item()), 0), len(self.database['database']) - 1)
            data_sample = self.database['database'][idx]
            samp_loss = self.mood_dist(data_sample, self.mood_prompt, search=False)
            print(samp_loss)
            # A perfect match: stop early.
            if samp_loss.item() >= 1.:
                break
        return -samp_loss
    def search_compose(self):
        # Tune the prior's mean/std so samples drawn from it index database
        # entries similar to the mood prompt. The optimizer is built once over
        # a leaf tensor; the original rebuilt it every iteration over non-leaf
        # slices and copied parameters back via state_dict, which had no
        # effect. Note: loss_fn is not differentiable through the sampling and
        # rounding steps, so the gradient step is inert unless that is reworked.
        self.prior_component = self.prior_component.clone().requires_grad_(True)
        optimizer = optim.Adagrad([self.prior_component])
        for _ in range(100):
            optimizer.zero_grad()
            optimizer.step(closure=self.loss_fn)
            with torch.no_grad():  # redraw samples from the updated prior
                self.prior_sample = torch.normal(
                    self.prior_component[0].item(),
                    abs(self.prior_component[1].item()),
                    size=(5,))
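

# A minimal, hedged usage sketch. The shelve files, the keys ('sample',
# 'database'), and the example mood strings below are assumptions inferred
# from how SearchMood reads its data, not part of the original code.
def build_demo_databases():
    # cx_sample.db is assumed to hold a list of mood states under 'sample'.
    with shelve.open('cx_sample.db') as db:
        db['sample'] = ['calm and content', 'restless and anxious']
    # database.db is assumed to hold searchable mood prompts under 'database'.
    with shelve.open('database.db') as db:
        db['database'] = ['a peaceful walk at dawn',
                          'a tense late-night deadline',
                          'a joyful family dinner']


if __name__ == '__main__':
    build_demo_databases()
    searcher = SearchMood(mood_prompt='calm and content', prior_init=None)
    # Map each stored mood state to its closest database entry...
    print(searcher.mood_dist())
    # ...then run the sampling-based search loop.
    searcher.search_compose()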