ai-forever commited on
Commit
3248018
·
verified ·
1 Parent(s): bbe14ec

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +96 -40
README.md CHANGED
@@ -4,55 +4,111 @@ language:
4
  - ru
5
  - en
6
  tags:
7
- - PyTorch
8
- - Transformers
9
  ---
10
 
11
- # ru-en RoBERTa large model for Sentence Embeddings in Russian and English.
12
- The model is described [in this article](<link of our arxiv>)
13
- Russian MTEB [metrics](<lin of our ruMTEB>)
14
-
15
- For better quality, use cls token embeddings.
16
- Also, use next prefixes for tasks:
17
- - For assimethric retrieval tasks like search/QuestAnsw: "search_query: "/"search_document: ".
18
- - NLI, NLU and paraphrasing tasks: "classification: ".
19
- - Title body/abstract and clasterization: "clustering: ".
20
- ## Usage (HuggingFace Models Repository)
21
- You can use the model directly from the model repository to compute sentence embeddings:
 
 
 
 
 
 
 
 
 
 
22
  ```python
23
- from transformers import AutoTokenizer, AutoModel
24
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
- #You might to use two variants of mode for embeddings creation:
27
- #CLS token embs or MEAN Pooling.
28
- #You can choose embs pooling with best quality for your downstream tasks.
29
-
30
- #Mean Pooling example - Take attention mask into account for correct averaging
31
- def mean_pooling(model_output, attention_mask):
32
- token_embeddings = model_output[0] #First element of model_output contains all token embeddings
33
- input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
34
- sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
35
- sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
36
- return sum_embeddings / sum_mask
37
-
38
- #Sentences we want sentence embeddings for
39
- sentences = ['Привет! Как твои дела?',
40
- 'А правда, что 42 твое любимое число?']
41
- #Load AutoModel from huggingface model repository
42
  tokenizer = AutoTokenizer.from_pretrained("ai-forever/ru-en-RoSBERTa")
43
  model = AutoModel.from_pretrained("ai-forever/ru-en-RoSBERTa")
44
- #Tokenize sentences
45
- encoded_input = tokenizer(sentences, padding=True, truncation=True, max_length=512, return_tensors='pt')
46
 
47
- #Compute token embeddings
 
48
  with torch.no_grad():
49
- model_output = model(**encoded_input)
50
 
51
- #In this case, mean pooling
52
- sentence_mean_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
- #In this case, cls "pooling"
55
- last_hidden_states = model_output[0]
56
- sentence_cls_embeddings = last_hidden_states[:,0]
57
 
58
- ```
 
4
  - ru
5
  - en
6
  tags:
7
+ - transformers
8
+ - sentence-transformers
9
  ---
10
 
11
+ # Model Card for ru-en-RoSBERTa
12
+
13
+ The ru-en-RoSBERTa is a general text embedding model for Russian. The model is based on [ruRoBERTa](https://huggingface.co/ai-forever/ruRoberta-large) and fine-tuned with ~4M pairs of supervised, synthetic and unsupervised data in Russian and English. Tokenizer supports some English tokens from [RoBERTa](https://huggingface.co/FacebookAI/roberta-large) tokenizer.
14
+
15
+ For more model details please refer to our [article](arxiv).
16
+
17
+ ## Usage
18
+
19
+ The model can be used as is with prefixes. It is recommended to use CLS pooling. The choice of prefix and pooling depends on the task.
20
+
21
+ We use the following basic rules to choose a prefix:
22
+ - `"search_query: "` and `"search_document: "` prefixes are for answer or relevant paragraph retrieval
23
+ - `"classification: "` prefix is for symmetric paraphrasing related tasks (STS, NLI, Bitext Mining)
24
+ - `"clustering: "` prefix is for any tasks that rely on thematic features (topic classification, title-body retrieval)
25
+
26
+ To better tailor the model to your needs, you can fine-tune it with relevant high-quality Russian and English datasets.
27
+
28
+ Below are examples of texts encoding using the Transformers and SentenceTransformers libraries.
29
+
30
+ ### Transformers
31
+
32
  ```python
 
33
  import torch
34
+ import torch.nn.functional as F
35
+ from transformers import AutoTokenizer, AutoModel
36
+
37
+
38
+ def pool(hidden_state, mask, pooling_method="cls"):
39
+ if pooling_method == "mean":
40
+ s = torch.sum(hidden_state * mask.unsqueeze(-1).float(), dim=1)
41
+ d = mask.sum(axis=1, keepdim=True).float()
42
+ return s / d
43
+ elif pooling_method == "cls":
44
+ return hidden_state[:, 0]
45
+
46
+ inputs = [
47
+ #
48
+ "classification: Он нам и <unk> не нужон ваш Интернет!",
49
+ "clustering: В Ярославской области разрешили работу бань, но без посетителей",
50
+ "search_query: Сколько программистов нужно, чтобы вкрутить лампочку?",
51
+
52
+ #
53
+ "classification: What a time to be alive!",
54
+ "clustering: Ярославским баням разрешили работать без посетителей",
55
+ "search_document: Чтобы вкрутить лампочку, требуется три программиста: один напишет программу извлечения лампочки, другой — вкручивания лампочки, а третий проведет тестирование.",
56
+ ]
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  tokenizer = AutoTokenizer.from_pretrained("ai-forever/ru-en-RoSBERTa")
59
  model = AutoModel.from_pretrained("ai-forever/ru-en-RoSBERTa")
 
 
60
 
61
+ tokenized_inputs = tokenizer(inputs, max_length=512, padding=True, truncation=True, return_tensors="pt")
62
+
63
  with torch.no_grad():
64
+ outputs = model(**tokenized_inputs)
65
 
66
+ embeddings = pool(
67
+ outputs.last_hidden_state,
68
+ tokenized_inputs["attention_mask"],
69
+ pooling_method="cls" # or try "mean"
70
+ )
71
+
72
+ embeddings = F.normalize(embeddings, p=2, dim=1)
73
+
74
+ sim_scores = embeddings[:3] @ embeddings[3:].T
75
+ print(sim_scores.diag().tolist())
76
+ # [0.4796873927116394, 0.9409002065658569, 0.7761015892028809]
77
+ ```
78
+
79
+ ### SentenceTransformers
80
+
81
+ ```python
82
+ from sentence_transformers import SentenceTransformer
83
+
84
+
85
+ inputs = [
86
+ #
87
+ "classification: Он нам и <unk> не нужон ваш Интернет!",
88
+ "clustering: В Ярославской области разрешили работу бань, но без посетителей",
89
+ "search_query: Сколько программистов нужно, чтобы вкрутить лампочку?",
90
+
91
+ #
92
+ "classification: What a time to be alive!",
93
+ "clustering: Ярославским баням разрешили работать без посетителей",
94
+ "search_document: Чтобы вкрутить лампочку, требуется три программиста: один напишет программу извлечения лампочки, другой — вкручивания лампочки, а третий проведет тестирование.",
95
+ ]
96
+
97
+ # loads model with CLS pooling
98
+ model = SentenceTransformer("ai-forever/ru-en-RoSBERTa")
99
+
100
+ # embeddings are normalized by default
101
+ embeddings = model.encode(inputs, convert_to_tensor=True)
102
+
103
+ sim_scores = embeddings[:3] @ embeddings[3:].T
104
+ print(sim_scores.diag().tolist())
105
+ # [0.47968706488609314, 0.940900444984436, 0.7761018872261047]
106
+ ```
107
+
108
+ ## Citation
109
+
110
+ TODO
111
 
112
+ ## Limitations
 
 
113
 
114
+ The model is designed to process texts in Russian, the quality in English is unknown. Maximum input text length is limited to 512 tokens.