import shelve

import torch
import torch.optim as optim
from scipy.cluster.hierarchy import fcluster, linkage
from sentence_transformers import SentenceTransformer, util


class SearchMood:
    def __init__(self, mood_prompt, prior_init):
        self.prior_init = prior_init
        self.mood_prompt = mood_prompt
        self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
        # Encode a (prompt, state) pair into two sentence embeddings.
        self.embedding = lambda prompt, state: (
            self.model.encode(prompt, convert_to_tensor=True),
            self.model.encode(state, convert_to_tensor=True),
        )
        # Cosine similarity between two embedding tensors.
        self.similar = lambda x, y: util.pytorch_cos_sim(x, y)
        # Persistent stores: 'cx_sample.db' holds observed mood states under the
        # key 'sample'; 'database.db' holds candidate prompts under 'database'.
        self.cx_sample = shelve.open('cx_sample.db')['sample']
        self.database = shelve.open('database.db')
        # Mean and standard deviation of the Gaussian search prior
        # (an instance attribute; the original set it on the class by mistake).
        self.prior_component = torch.tensor([0., 1.], requires_grad=True)
        # torch.normal() with an explicit size expects float mean/std.
        self.prior_sample = torch.normal(
            self.prior_component[0].item(),
            self.prior_component[1].item(),
            size=(5,),
        )
        self.sample_losses = None

    def Hierarchical(self, data):
        # Complete-linkage hierarchical clustering over a batch of embeddings.
        # The original called a nonexistent hierarchical() helper; SciPy's
        # linkage/fcluster are used here instead. `data` is assumed to be a
        # 2-D tensor (one embedding per row); the two-cluster cutoff is an
        # illustrative choice.
        links = linkage(data.cpu().numpy(), method='complete', metric='cosine')
        clusters = fcluster(links, t=2, criterion='maxclust')
        print(clusters)
        return clusters

    def embeddings(self, samplex, sampley):
        # Embed both texts and compare them directly. The original routed each
        # single-sentence embedding through Hierarchical(), which returned None
        # and cannot cluster a lone observation; cosine similarity on the raw
        # embeddings is used instead, and Hierarchical() is kept as a
        # standalone batch utility.
        embx, emby = self.embedding(samplex, sampley)
        return self.similar(embx, emby)

    def mood_dist(self, data_sample=False, mood_prompt=False, search=True):
        cx_index = []
        if search:
            # For each observed mood state, record the index of the most
            # similar prompt in the database.
            for mood_state in self.cx_sample:
                max_sample = 0.
                index_sample = 0
                for index, mood_prompts in enumerate(self.database['database']):
                    simemb = float(self.embeddings(mood_state, mood_prompts))
                    if max_sample < simemb:
                        max_sample = simemb
                        index_sample = index
                cx_index.append(float(index_sample))
        else:
            # Score a single (prompt, sample) pair instead of searching.
            cx_index.append(float(self.embeddings(mood_prompt, data_sample)))
        return torch.tensor(cx_index)

    def loss_fn(self):
        # Score the database entries indexed by the prior's draws; stop early
        # on a (near-)perfect match.
        samp_loss = torch.tensor([0.])
        for sample in self.prior_sample:
            # Round the continuous draw to a database index (negative draws
            # index from the end of the list, per Python semantics).
            data_sample = self.database['database'][round(sample.item())]
            samp_loss = self.mood_dist(data_sample, self.mood_prompt, search=False)
            print(samp_loss)
            if samp_loss.item() >= 1.:
                break
        # Negate: the optimizer minimizes, and we want maximum similarity.
        return -samp_loss

    def search_compose(self):
        # Adagrad over the prior's (mean, std). The optimizer is built once,
        # outside the loop, so its accumulated state survives across rounds;
        # it must receive the leaf tensor itself, not indexed views of it
        # (and state_dict()['param_groups'][0]['params'] holds integer ids,
        # not tensors, so the original read-back could never work).
        optimizer = optim.Adagrad([self.prior_component])
        for _ in range(100):
            optimizer.step(closure=self.loss_fn)
            # Redraw candidate indices from the current prior. Note that
            # torch.normal() detaches its samples, so no gradient actually
            # reaches prior_component through this search loop.
            self.prior_sample = torch.normal(
                self.prior_component[0].item(),
                self.prior_component[1].item(),
                size=(5,),
            )
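

# A minimal usage sketch, assuming the two shelve files follow the schema the
# class reads: 'cx_sample.db' stores mood-state strings under the key 'sample'
# and 'database.db' stores a list of prompt strings under 'database'. The texts
# below are made-up placeholders, and prior_init is unused by the class.
if __name__ == '__main__':
    with shelve.open('database.db') as db:
        db['database'] = ['a calm evening walk', 'an upbeat workout mix',
                          'rainy-day reading', 'late-night focus']
    with shelve.open('cx_sample.db') as db:
        db['sample'] = ['feeling mellow and unhurried']

    searcher = SearchMood(mood_prompt='something relaxing', prior_init=None)
    print(searcher.mood_dist())  # nearest database index per mood state
    searcher.search_compose()    # run the prior-guided search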