license: apache-2.0

Mutual Information Contrastive Sentence Embedding (miCSE)

Language model accompanying the arXiv pre-print "miCSE: Mutual Information Contrastive Learning for Low-shot Sentence Embeddings".

The miCSE language model is trained to compute sentence similarity. During contrastive learning, training enforces alignment between the attention patterns of different views (embeddings of augmentations) of the same sentence. Learning sentence embeddings with miCSE therefore means enforcing syntactic consistency across the augmented views of every single sentence, which makes contrastive self-supervised learning more sample-efficient. The sentence representation is the embedding of the [CLS] token.
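For orientation, the following is a minimal sketch of a generic contrastive (InfoNCE) objective over two augmented views: each sentence's [CLS] vector from one view should be most similar to the vector of the same sentence in the other view, with the rest of the batch acting as negatives. This is an illustration under stated assumptions, not the miCSE training code. It deliberately omits miCSE's attention-alignment (mutual information) term, and the function name `info_nce_loss` and the temperature of 0.05 are hypothetical choices made for this example.

```python
import numpy as np


def info_nce_loss(z1: np.ndarray, z2: np.ndarray, temperature: float = 0.05) -> float:
    """Generic InfoNCE loss; row i of z1 and row i of z2 are two views of sentence i.

    Hypothetical helper: the temperature value is an illustrative assumption, and
    miCSE's attention-alignment term is NOT included.
    """
    # L2-normalise rows so that dot products equal cosine similarities
    z1 = z1 / np.linalg.norm(z1, axis=1, keepdims=True)
    z2 = z2 / np.linalg.norm(z2, axis=1, keepdims=True)

    # (N, N) logits: entry (i, j) compares view 1 of sentence i with view 2 of sentence j
    logits = z1 @ z2.T / temperature

    # Numerically stable row-wise log-softmax
    logits = logits - logits.max(axis=1, keepdims=True)
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))

    # The positive pair for sentence i lies on the diagonal; average its negative log-probability
    return float(-np.mean(np.diag(log_probs)))


views = np.eye(4)
print(info_nce_loss(views, views))              # near zero: every positive pair is the best match
print(info_nce_loss(views, views[[1, 2, 3, 0]]))  # large: positives were shuffled away
```

Lower loss means the two views of each sentence agree more closely than views of different sentences do.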

Model Usage

import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("sap-ai-research/miCSE")

model = AutoModel.from_pretrained("sap-ai-research/miCSE")
model.eval()  # inference mode (disables dropout)


# Encode a list of sentences, truncating each to at most max_length tokens

max_length = 32

sentences = [
    "This is a sentence for testing miCSE.",
    "This is yet another test sentence for the mutual information Contrastive Sentence Embeddings model."
]

batch = tokenizer(
    sentences,
    return_tensors="pt",
    padding=True,
    max_length=max_length,
    truncation=True
)

# Compute the embeddings and keep only the [CLS] embedding (the first token)

with torch.no_grad():
    outputs = model(**batch, return_dict=True)

embeddings = outputs.last_hidden_state[:, 0]

# Define the similarity metric, e.g., cosine similarity

sim = nn.CosineSimilarity(dim=-1)

# Compute the pairwise similarity matrix; entry [0, 1] compares the first and the second sentence

cos_sim = sim(embeddings.unsqueeze(1), embeddings.unsqueeze(0))

print(f"Cosine similarity: {cos_sim[0, 1].item()}")
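The `unsqueeze(1)` / `unsqueeze(0)` pair in the usage snippet makes `nn.CosineSimilarity` broadcast to a full N × N matrix of pairwise scores. As a sketch, the same computation in NumPy looks like this, which may be handy once embeddings have been exported as arrays. `pairwise_cosine` is a hypothetical helper, and its zero-norm guard only approximates PyTorch's `eps` handling.

```python
import numpy as np


def pairwise_cosine(embeddings: np.ndarray) -> np.ndarray:
    """Return the (N, N) matrix of cosine similarities between the rows of `embeddings`."""
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    unit = embeddings / np.maximum(norms, 1e-8)  # guard against zero vectors
    return unit @ unit.T  # entry (i, j) = cos(row i, row j)


toy = np.array([[1.0, 0.0], [1.0, 1.0]])
sim_matrix = pairwise_cosine(toy)
print(round(float(sim_matrix[0, 1]), 4))  # prints 0.7071
```

The diagonal is always 1 (each sentence compared with itself), so only off-diagonal entries carry information.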

Benchmark

Model results on the SentEval benchmark:

+-------+-------+-------+-------+-------+--------------+-----------------+--------+                                               
| STS12 | STS13 | STS14 | STS15 | STS16 | STSBenchmark | SICKRelatedness | S.Avg. |                                               
+-------+-------+-------+-------+-------+--------------+-----------------+--------+                                               
| 71.71 | 83.09 | 75.46 | 83.13 | 80.22 |    79.70     |      73.62      | 78.13  |                                               
+-------+-------+-------+-------+-------+--------------+-----------------+--------+  
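The S.Avg. column is consistent with an unweighted arithmetic mean of the seven task scores. The short check below recomputes it from the table values; it is a sanity check on the reported numbers, not part of the SentEval tooling.

```python
# Per-task scores copied from the SentEval table
scores = {
    "STS12": 71.71,
    "STS13": 83.09,
    "STS14": 75.46,
    "STS15": 83.13,
    "STS16": 80.22,
    "STSBenchmark": 79.70,
    "SICKRelatedness": 73.62,
}

# Unweighted arithmetic mean over the seven tasks
s_avg = sum(scores.values()) / len(scores)
print(f"S.Avg. = {s_avg:.2f}")  # prints "S.Avg. = 78.13"
```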

Citations

If you use this code in your research or want to refer to our work, please cite:

@article{Klein2022miCSEMI,
  title={miCSE: Mutual Information Contrastive Learning for Low-shot Sentence Embeddings},
  author={Tassilo Klein and Moin Nabi},
  journal={ArXiv},
  year={2022},
  volume={abs/2211.04928}
}

Authors: