---
language:
- en
library_name: transformers
datasets:
- facebook/anli
- zen-E/ANLI-simcse-roberta-large-embeddings-pca-256
metrics:
- spearmanr
- pearsonr
---

This model is trained by knowledge distillation from "princeton-nlp/unsup-simcse-roberta-large" (teacher) into "zen-E/bert-mini-sentence-distil-unsupervised" (student) on the ANLI dataset.

The model can be loaded for inference with `AutoModel`.
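
A minimal loading sketch is shown below; the repo id is a placeholder for this model's Hub id, and the pooler output is assumed to be the sentence embedding (it is the second BERT output used in the training forward pass).

```python
import torch
from transformers import AutoModel, AutoTokenizer

repo_id = "<this-model-id>"  # placeholder: replace with this model's Hub id

tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModel.from_pretrained(repo_id).eval()

sentences = ["A man is playing a guitar.", "Someone plays an instrument."]
inputs = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# assumption: pooler output as the sentence embedding, as in the training forward
embeddings = outputs.pooler_output

# cosine similarity between the two example sentences
similarity = torch.nn.functional.cosine_similarity(embeddings[0], embeddings[1], dim=0)
print(similarity.item())
```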

The model achieves a Pearson correlation of 0.836 and a Spearman correlation of 0.840 on the STS-B test set.
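
For reference, the correlations can be computed with a sketch along these lines, assuming `sentences1`, `sentences2`, and `gold_scores` hold the STS-B test pairs and gold similarity scores, and `tokenizer`/`model` are loaded as above:

```python
import torch
from scipy.stats import pearsonr, spearmanr

def encode(sentences):
    # hypothetical helper: embed a list of sentences with the student model
    inputs = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        return model(**inputs).pooler_output

preds = torch.nn.functional.cosine_similarity(encode(sentences1), encode(sentences2)).tolist()

print("pearsonr: ", pearsonr(preds, gold_scores)[0])
print("spearmanr:", spearmanr(preds, gold_scores)[0])
```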

For more training detail, the training config and the PyTorch forward function are given below. The teacher's features are first projected to a size of 256 by the PCA object in "zen-E/ANLI-simcse-roberta-large-embeddings-pca-256", which can be loaded as follows:

```python
import joblib

# load the fitted PCA that maps the teacher's 1024-dim SimCSE embeddings down to 256 dims
pca = joblib.load('ANLI-simcse-roberta-large-embeddings-pca-256/pca_model.sav')

# `features`: teacher sentence embeddings (see the sketch below)
features_256 = pca.transform(features)
```
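
The `features` passed to `pca.transform` are the teacher's sentence embeddings. A sketch of producing them, assuming the CLS-token embedding of the SimCSE teacher is used as the sentence representation, might look like:

```python
import torch
from transformers import AutoModel, AutoTokenizer

teacher_name = "princeton-nlp/unsup-simcse-roberta-large"
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_name)
teacher = AutoModel.from_pretrained(teacher_name).eval()

batch = teacher_tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
    out = teacher(**batch)

# assumption: CLS-token embeddings as the teacher features fed to the PCA
features = out.last_hidden_state[:, 0].numpy()
features_256 = pca.transform(features)  # 1024 -> 256 dimensions
```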

```python
config = {
  'epoch': 10,
  'learning_rate': 5e-5,
  'batch_size': 512,
  'temperature': 0.05
}
```
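
As an illustration only (not the original training script), these hyperparameters would typically be wired up as below; `train_loader`, yielding ANLI triplet batches, and `student`, exposing the `forward_cos_mse_kd` method shown next, are hypothetical names:

```python
import torch

optimizer = torch.optim.AdamW(student.parameters(), lr=config['learning_rate'])

for epoch in range(config['epoch']):
    for batch in train_loader:  # batches of config['batch_size'] ANLI triplets plus teacher embeddings
        optimizer.zero_grad()
        total_loss, kd_cos_loss, kd_mse_loss = student.forward_cos_mse_kd(*batch)
        total_loss.backward()
        optimizer.step()
```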

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def forward_cos_mse_kd(self, sentence1s, sentence2s, sentence3s,
                       teacher_sentence1_embs, teacher_sentence2_embs, teacher_sentence3_embs):
    """Forward function for the ANLI dataset."""
    # student sentence embeddings (second element of the BERT tuple output, i.e. the pooler output)
    _, o1 = self.bert(**sentence1s)
    _, o2 = self.bert(**sentence2s)
    _, o3 = self.bert(**sentence3s)

    # compute the student's cosine similarity between sentences
    cos_o1_o2 = cosine_sim(o1, o2)
    cos_o1_o3 = cosine_sim(o1, o3)

    # compute the teacher's cosine similarity between sentences
    cos_o1_o2_t = cosine_sim(teacher_sentence1_embs, teacher_sentence2_embs)
    cos_o1_o3_t = cosine_sim(teacher_sentence1_embs, teacher_sentence3_embs)

    cos_sim = torch.cat((cos_o1_o2, cos_o1_o3), dim=-1)
    cos_sim_t = torch.cat((cos_o1_o2_t, cos_o1_o3_t), dim=-1)

    # KL divergence between student and teacher similarity distributions
    soft_teacher_probs = F.softmax(cos_sim_t / self.temperature, dim=1)
    kd_cos_loss = F.kl_div(F.log_softmax(cos_sim / self.temperature, dim=1),
                           soft_teacher_probs,
                           reduction='batchmean')

    # MSE between student embeddings and the PCA-reduced teacher embeddings
    o = torch.cat([o1, o2, o3], dim=0)
    teacher_embs = torch.cat([teacher_sentence1_embs, teacher_sentence2_embs, teacher_sentence3_embs], dim=0)
    kd_mse_loss = nn.MSELoss()(o, teacher_embs) / 3

    # equal weight for the two losses
    total_loss = kd_cos_loss * 0.5 + kd_mse_loss * 0.5
    return total_loss, kd_cos_loss, kd_mse_loss
```
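
`cosine_sim` is not defined in the snippet above. A plausible minimal definition, assuming it returns one similarity value per pair with a trailing dimension so the two `torch.cat(..., dim=-1)` calls produce a `[batch, 2]` tensor, is:

```python
import torch.nn.functional as F

def cosine_sim(a, b):
    # per-pair cosine similarity with shape [batch, 1], so that concatenating the
    # (o1, o2) and (o1, o3) similarities along dim=-1 yields a [batch, 2] tensor
    return F.cosine_similarity(a, b, dim=-1).unsqueeze(-1)
```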