---
language:
  - ru
tags:
  - sentence-similarity
  - text-classification
datasets:
  - merionum/ru_paraphraser
  - RuPAWS
---

This is a cross-encoder model trained to predict semantic equivalence of two Russian sentences.

It classifies text pairs as paraphrases (class 1) or non-paraphrases (class 0). Its scores can be used as a metric of content preservation for paraphrasing or text style transfer.
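The model outputs a probability of class 1 for each pair, so per-pair scores can be aggregated into a corpus-level content-preservation metric, for example when evaluating a paraphraser or a style-transfer system. The sketch below shows one way to do this. The helper name `content_preservation_score`, the 0.5 threshold, and the choice to report both the mean probability and the share of pairs at or above the threshold are assumptions for illustration, not a protocol defined by this card. `score_fn` can be the `get_similarity` function from the usage example below.

```python
from typing import Callable, Sequence, Tuple


def content_preservation_score(
    sources: Sequence[str],
    rewrites: Sequence[str],
    score_fn: Callable[[str, str], float],
    threshold: float = 0.5,  # hypothetical decision threshold
) -> Tuple[float, float]:
    """Return (mean paraphrase probability, share of pairs scored >= threshold).

    Hypothetical helper: score_fn is any function returning P(paraphrase)
    for a pair of texts, e.g. get_similarity from the usage example.
    """
    if len(sources) != len(rewrites):
        raise ValueError("sources and rewrites must have the same length")
    if not sources:
        raise ValueError("need at least one pair")
    scores = [score_fn(src, out) for src, out in zip(sources, rewrites)]
    mean_score = sum(scores) / len(scores)
    share_preserved = sum(s >= threshold for s in scores) / len(scores)
    return mean_score, share_preserved
```

Reporting the mean probability keeps the metric continuous, while the thresholded share is easier to interpret as "fraction of outputs that kept the meaning".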

It is the sberbank-ai/ruRoberta-large model fine-tuned on the union of three datasets:

  1. RuPAWS: https://github.com/ivkrotova/rupaws_dataset, based on Quora Question Pairs (QQP) and Wikipedia;
  2. ru_paraphraser: https://huggingface.co/merionum/ru_paraphraser;
  3. Results of the manual check of content preservation for the RUSSE-2022 text detoxification dataset collection (content_5.tsv).

The task was formulated as binary classification: whether the two sentences have the same meaning (1) or different meanings (0).
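Since "same meaning" is a symmetric relation, a pair and its reversal carry the same label, so each training pair can also be used in reversed order. A minimal sketch of that augmentation, assuming pairs are stored as `(text1, text2, label)` tuples; the function name `add_reversed_pairs` is hypothetical and the card does not include the actual preprocessing code.

```python
from typing import List, Tuple

Pair = Tuple[str, str, int]  # (text1, text2, label); hypothetical layout


def add_reversed_pairs(pairs: List[Pair]) -> List[Pair]:
    """Return every pair in both orders: (text1, text2) and (text2, text1).

    The label is symmetric (same meaning or not), so the reversed pair keeps it.
    """
    augmented: List[Pair] = []
    for text1, text2, label in pairs:
        augmented.append((text1, text2, label))
        augmented.append((text2, text1, label))
    return augmented


data = [("Я тебя люблю", "Ты мне нравишься", 1)]  # "I love you" / "I like you"
print(add_reversed_pairs(data))
```

Training on both orders encourages the cross-encoder, which sees the two sentences concatenated, to give similar scores regardless of which sentence comes first.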

The table shows the training dataset size after duplication (each pair was included in both orders, text1 + text2 and text2 + text1):

| Source | Label 0 (different meaning) | Label 1 (same meaning) |
|---|---:|---:|
| detox | 1412 | 3843 |
| paraphraser | 5539 | 1688 |
| rupaws_qqp | 1112 | 792 |
| rupaws_wiki | 3526 | 2166 |
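The class balance differs considerably between sources. The few lines below are just a convenience calculation on the numbers in the table, not code from the original training pipeline.

```python
# Training-set sizes from the table: source -> (label 0, label 1)
counts = {
    "detox": (1412, 3843),
    "paraphraser": (5539, 1688),
    "rupaws_qqp": (1112, 792),
    "rupaws_wiki": (3526, 2166),
}

total_neg = sum(neg for neg, _ in counts.values())
total_pos = sum(pos for _, pos in counts.values())
print(f"total: {total_neg + total_pos}, label 0: {total_neg}, label 1: {total_pos}")

# Share of paraphrases (label 1) within each source
for source, (neg, pos) in counts.items():
    print(f"{source}: {pos / (neg + pos):.1%} paraphrases")
```

This gives 20078 training examples in total, 8489 of them (about 42%) labelled as paraphrases; the share of paraphrases ranges from about 23% (paraphraser) to about 73% (detox).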

The model was trained with the Adam optimizer and the following hyperparameters:

- `learning_rate` = 1e-5
- `batch_size` = 8
- `gradient_accumulation_steps` = 4 (effective batch size 8 × 4 = 32)
- `n_epochs` = 3
- `max_grad_norm` = 1.0
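The original training script is not part of this card. The loop below is a minimal sketch of how these hyperparameters typically fit together: gradients from 4 micro-batches of 8 are accumulated (with the loss divided by 4, an assumption about how averaging was done), the gradient norm is clipped to 1.0, and only then does Adam take a step. The toy `nn.Linear` classifier and random data are placeholders standing in for ruRoberta-large and the tokenized sentence pairs, so the loop runs anywhere.

```python
import torch
from torch import nn

# Hyperparameters from the list above
LEARNING_RATE = 1e-5
BATCH_SIZE = 8
GRAD_ACCUM_STEPS = 4  # effective batch size = 8 * 4 = 32
N_EPOCHS = 3
MAX_GRAD_NORM = 1.0

torch.manual_seed(0)
# Placeholder model and data (assumption): not the real cross-encoder setup
model = nn.Linear(16, 2)
features = torch.randn(320, 16)
labels = torch.randint(0, 2, (320,))

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_fn = nn.CrossEntropyLoss()

optimizer_steps = 0
for epoch in range(N_EPOCHS):
    perm = torch.randperm(len(features))
    # Clears gradients; any leftover partial accumulation from the last epoch is dropped
    optimizer.zero_grad()
    for i, start in enumerate(range(0, len(features), BATCH_SIZE)):
        idx = perm[start:start + BATCH_SIZE]
        logits = model(features[idx])
        # Scale the loss so the accumulated gradient averages over the micro-batches
        loss = loss_fn(logits, labels[idx]) / GRAD_ACCUM_STEPS
        loss.backward()
        if (i + 1) % GRAD_ACCUM_STEPS == 0:
            nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
            optimizer.step()
            optimizer.zero_grad()
            optimizer_steps += 1

print(optimizer_steps)  # 40 micro-batches per epoch -> 10 steps per epoch -> 30
```

Gradient accumulation lets a large model like ruRoberta-large train with an effective batch of 32 while only 8 pairs need to fit in memory at a time.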

After training, the model had the following ROC AUC scores on the test sets:

| Test set | ROC AUC |
|---|---:|
| detox | 0.857112 |
| paraphraser | 0.858465 |
| rupaws_qqp | 0.859195 |
| rupaws_wiki | 0.906121 |
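ROC AUC measures how well the predicted probability of class 1 ranks true paraphrases above non-paraphrases, independently of any decision threshold. A minimal pure-Python sketch of the metric, for illustration only; in practice `sklearn.metrics.roc_auc_score(y_true, y_score)` computes the same quantity, and the evaluation script behind the table is not part of this card.

```python
from typing import Sequence


def roc_auc(labels: Sequence[int], scores: Sequence[float]) -> float:
    """ROC AUC as the probability that a random positive outscores a random negative.

    Ties count as half a win. O(n_pos * n_neg), fine for small test sets.
    """
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative example")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


print(roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```

An AUC of about 0.86 thus means that a randomly chosen paraphrase pair receives a higher score than a randomly chosen non-paraphrase pair roughly 86% of the time.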

Example usage:

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained('SkolkovoInstitute/ruRoberta-large-paraphrase-v1')
tokenizer = AutoTokenizer.from_pretrained('SkolkovoInstitute/ruRoberta-large-paraphrase-v1')

def get_similarity(text1, text2):
    """ Predict the probability that two Russian sentences are paraphrases of each other. """
    with torch.inference_mode():
        batch = tokenizer(
            text1, text2,
            # RoBERTa position ids start at padding_idx + 1 = 2, so 2 position embeddings are unusable
            truncation=True, max_length=model.config.max_position_embeddings - 2, return_tensors='pt',
        ).to(model.device)
        proba = torch.softmax(model(**batch).logits, -1)
    # Column 1 is the "paraphrase" class
    return proba[0][1].item()

print(get_similarity('Я тебя люблю', 'Ты мне нравишься'))  # 0.9798 ("I love you" vs "I like you")
print(get_similarity('Я тебя люблю', 'Я тебя ненавижу'))   # 0.0008 ("I love you" vs "I hate you")
```
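To score many pairs at once, for example when computing content preservation over a whole corpus, the tokenizer also accepts two aligned lists of texts. A minimal batched sketch: `get_similarity_batch` and its `batch_size` default are hypothetical, `max_length=512` assumes the usual RoBERTa input limit, and `model`/`tokenizer` are passed explicitly instead of being read from globals.

```python
from typing import List, Sequence

import torch


def get_similarity_batch(
    texts1: Sequence[str],
    texts2: Sequence[str],
    model,
    tokenizer,
    batch_size: int = 32,  # hypothetical default; tune to fit GPU memory
    max_length: int = 512,  # assumes the usual 512-token RoBERTa input limit
) -> List[float]:
    """Paraphrase probabilities for aligned lists of sentence pairs (hypothetical helper)."""
    if len(texts1) != len(texts2):
        raise ValueError("texts1 and texts2 must have the same length")
    scores: List[float] = []
    with torch.inference_mode():
        for start in range(0, len(texts1), batch_size):
            batch = tokenizer(
                list(texts1[start:start + batch_size]),
                list(texts2[start:start + batch_size]),
                padding=True,  # pad the pairs in a batch to a common length
                truncation=True,
                max_length=max_length,
                return_tensors='pt',
            ).to(model.device)
            proba = torch.softmax(model(**batch).logits, -1)
            scores.extend(proba[:, 1].tolist())  # column 1 = "paraphrase" class
    return scores
```

For example, `get_similarity_batch(['Я тебя люблю'], ['Ты мне нравишься'], model, tokenizer)` returns a one-element list with the same kind of score as `get_similarity`; padding within a batch is masked out by the attention mask, so scores should match the unbatched version up to numerical noise.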