---
language:
- ja
library_name: sentence-transformers
tags:
- sentence-transformers
- sentence-similarity
- feature-extraction
metrics:
widget: []
pipeline_tag: sentence-similarity
license: apache-2.0
datasets:
- hpprc/emb
- hpprc/mqa-ja
- google-research-datasets/paws-x
---

# RoSEtta

RoSEtta (**Ro**Former-based **S**entence **E**ncoder **t**hrough Dis**t**ill**a**tion) is a general-purpose Japanese text embedding model that excels at retrieval tasks. It accepts inputs of up to 1024 tokens, so long passages can be embedded. It can run on a CPU and is designed both to measure semantic similarity between sentences and to act as a retriever that finds passages relevant to a query.

Key features:

- Uses RoPE (Rotary Position Embedding)
- Maximum sequence length of 1024 tokens
- Distilled from large sentence embedding models
- Specialized for retrieval tasks
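
To illustrate what RoPE does, here is a minimal NumPy sketch of one common formulation, which rotates consecutive feature pairs by position-dependent angles as in the original RoFormer paper. It is not RoSEtta's actual implementation. The function name `rotary_embed` and the base of 10000 are assumptions made for illustration. The check at the end shows the key property: the dot product between a rotated query and a rotated key depends only on the *relative* distance between their positions.

```python
import numpy as np

def rotary_embed(x: np.ndarray, position: int, base: float = 10000.0) -> np.ndarray:
    """Hypothetical helper: apply RoPE to a single vector `x` at `position`."""
    d = x.shape[-1]
    assert d % 2 == 0, "feature dimension must be even"
    # One rotation frequency per feature pair: theta_i = base^(-2i/d)
    inv_freq = base ** (-np.arange(0, d, 2) / d)
    angles = position * inv_freq
    cos, sin = np.cos(angles), np.sin(angles)
    x_even, x_odd = x[0::2], x[1::2]
    out = np.empty_like(x)
    # 2-D rotation of each (even, odd) feature pair
    out[0::2] = x_even * cos - x_odd * sin
    out[1::2] = x_even * sin + x_odd * cos
    return out

rng = np.random.default_rng(0)
q, k = rng.normal(size=8), rng.normal(size=8)

# Both pairs of positions are 3 apart, so the attention scores match.
score_a = rotary_embed(q, 5) @ rotary_embed(k, 2)
score_b = rotary_embed(q, 13) @ rotary_embed(k, 10)
print(np.isclose(score_a, score_b))  # True
```

Because the encoding is a rotation, it also preserves vector norms. Positional information therefore enters attention only through the relative angle between queries and keys.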

At inference time, every input must start with the prefix "query: " or "passage: ". See the Usage section for details.
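
A missing prefix is easy to overlook, so it can help to add it programmatically. The helper below is a hypothetical convenience function and is not part of any library. Only the two prefix strings come from this model card.

```python
from typing import Iterable, List

QUERY_PREFIX = "query: "
PASSAGE_PREFIX = "passage: "

def with_prefix(texts: Iterable[str], kind: str) -> List[str]:
    """Hypothetical helper: prepend the RoSEtta prefix for `kind` ("query" or "passage")."""
    prefixes = {"query": QUERY_PREFIX, "passage": PASSAGE_PREFIX}
    if kind not in prefixes:
        raise ValueError(f"kind must be 'query' or 'passage', got {kind!r}")
    prefix = prefixes[kind]
    # Skip texts that already carry the prefix so it is never added twice.
    return [t if t.startswith(prefix) else prefix + t for t in texts]

queries = with_prefix(["日本で一番高い山は?"], "query")
passages = with_prefix(["研究開発したアルゴリズムを、多くの企業のソフトウエア・オペレーションに導入しています。"], "passage")
print(queries)  # ['query: 日本で一番高い山は?']
```

For tasks other than retrieval, such as computing sentence similarity, the Usage section suggests using the "query: " prefix for every text.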

## Model Description

This model is based on the RoFormer architecture. After pre-training with an MLM loss, the model underwent weakly supervised learning. It was then trained further through distillation from several large embedding models and through multi-stage contrastive learning (as in [GLuCoSE v2](https://huggingface.co/pkshatech/GLuCoSE-base-ja-v2)).

- **Maximum Sequence Length:** 1024 tokens
- **Output Dimensionality:** 768
- **Similarity Function:** Cosine Similarity
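
Cosine similarity equals the dot product of L2-normalized vectors. The sketch below uses random tensors as stand-ins for four 768-dimensional RoSEtta embeddings. It checks that the broadcasting form used in the Usage examples matches the normalize-then-matmul form. The second form is convenient when passage embeddings are normalized once and reused across many queries.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
embeddings = torch.randn(4, 768)  # random stand-in for four 768-dimensional sentence embeddings

# Pairwise cosine similarity via broadcasting, as in the Usage examples
sim_broadcast = F.cosine_similarity(embeddings.unsqueeze(0), embeddings.unsqueeze(1), dim=2)

# Equivalent: L2-normalize each row, then take all pairwise dot products
normalized = F.normalize(embeddings, p=2, dim=1)
sim_matmul = normalized @ normalized.T

print(torch.allclose(sim_broadcast, sim_matmul, atol=1e-5))  # True
```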

## Usage

### Direct Usage (Sentence Transformers)

You can run inference with Sentence Transformers as follows:

```python
from sentence_transformers import SentenceTransformer
import torch.nn.functional as F

# Download from the 🤗 Hub
# The argument "trust_remote_code=True" is required to load the model
model = SentenceTransformer("pkshatech/RoSEtta-base-ja", trust_remote_code=True)

# Each input text should start with "query: " or "passage: ".
# For tasks other than retrieval, you can simply use the "query: " prefix.
sentences = [
    'query: PKSHAはどんな会社ですか?',
    'passage: 研究開発したアルゴリズムを、多くの企業のソフトウエア・オペレーションに導入しています。',
    'query: 日本で一番高い山は?',
    'passage: 富士山(ふじさん)は、標高3776.12 m、日本最高峰(剣ヶ峰)の独立峰で、その優美な風貌は日本国外でも日本の象徴として広く知られている。',
]
embeddings = model.encode(sentences, convert_to_tensor=True)
print(embeddings.shape)
# [4, 768]

# Get the similarity scores for the embeddings
similarities = F.cosine_similarity(embeddings.unsqueeze(0), embeddings.unsqueeze(1), dim=2)
print(similarities)
# [[1.0000, 0.5910, 0.4332, 0.5421],
#  [0.5910, 1.0000, 0.4977, 0.6969],
#  [0.4332, 0.4977, 1.0000, 0.7475],
#  [0.5421, 0.6969, 0.7475, 1.0000]]
```
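
For retrieval, you typically rank passages per query rather than printing the full similarity matrix. The sketch below uses a hypothetical helper, `top_k_passages`, and toy 3-dimensional vectors that stand in for `model.encode(...)` outputs. In practice you would encode queries with the "query: " prefix and passages with the "passage: " prefix, using `convert_to_tensor=True`. Sentence Transformers also provides `sentence_transformers.util.semantic_search`, which performs a similar ranking.

```python
import torch
import torch.nn.functional as F

def top_k_passages(query_emb: torch.Tensor, passage_embs: torch.Tensor, k: int = 3):
    """Hypothetical helper: (scores, indices) of the k passages most cosine-similar to each query."""
    q = F.normalize(query_emb, dim=-1)
    p = F.normalize(passage_embs, dim=-1)
    scores = q @ p.T  # [num_queries, num_passages]
    k = min(k, passage_embs.shape[0])
    return torch.topk(scores, k=k, dim=-1)

# Toy vectors standing in for real RoSEtta embeddings
passages = torch.tensor([[1.0, 0.0, 0.0],
                         [0.0, 1.0, 0.0],
                         [0.7, 0.7, 0.0]])
queries = torch.tensor([[0.9, 0.1, 0.0]])

scores, indices = top_k_passages(queries, passages, k=2)
print(indices.tolist())  # [[0, 2]]
```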

### Direct Usage (Transformers)

You can run inference with Transformers as follows:

```python
import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel

def mean_pooling(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    # Average the token embeddings, ignoring padded positions
    emb = last_hidden_states * attention_mask.unsqueeze(-1)
    emb = emb.sum(dim=1) / attention_mask.sum(dim=1).unsqueeze(-1)
    return emb

# Download from the 🤗 Hub
tokenizer = AutoTokenizer.from_pretrained("pkshatech/RoSEtta-base-ja")
# The argument "trust_remote_code=True" is required to load the model
model = AutoModel.from_pretrained("pkshatech/RoSEtta-base-ja", trust_remote_code=True)

# Each input text should start with "query: " or "passage: ".
# For tasks other than retrieval, you can simply use the "query: " prefix.
sentences = [
    'query: PKSHAはどんな会社ですか?',
    'passage: 研究開発したアルゴリズムを、多くの企業のソフトウエア・オペレーションに導入しています。',
    'query: 日本で一番高い山は?',
    'passage: 富士山(ふじさん)は、標高3776.12 m、日本最高峰(剣ヶ峰)の独立峰で、その優美な風貌は日本国外でも日本の象徴として広く知られている。',
]

# Tokenize the input texts
batch_dict = tokenizer(sentences, max_length=1024, padding=True, truncation=True, return_tensors='pt')

# Run the encoder without tracking gradients, then mean-pool the token embeddings
with torch.no_grad():
    outputs = model(**batch_dict)
embeddings = mean_pooling(outputs.last_hidden_state, batch_dict['attention_mask'])
print(embeddings.shape)
# [4, 768]

# Get the similarity scores for the embeddings
similarities = F.cosine_similarity(embeddings.unsqueeze(0), embeddings.unsqueeze(1), dim=2)
print(similarities)
# [[1.0000, 0.5910, 0.4332, 0.5421],
#  [0.5910, 1.0000, 0.4977, 0.6969],
#  [0.4332, 0.4977, 1.0000, 0.7475],
#  [0.5421, 0.6969, 0.7475, 1.0000]]
```
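
The attention mask passed to `mean_pooling` ensures that padding added by `padding=True` does not change a sentence's embedding. The self-contained check below uses random tensors in place of real hidden states. It copies the `mean_pooling` function from the example above and compares a sequence pooled on its own with the same sequence padded by two extra positions.

```python
import torch
from torch import Tensor

def mean_pooling(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    # Zero out padded positions, then average over the real tokens only
    emb = last_hidden_states * attention_mask.unsqueeze(-1)
    return emb.sum(dim=1) / attention_mask.sum(dim=1).unsqueeze(-1)

torch.manual_seed(0)
hidden = torch.randn(1, 3, 4)          # one sequence of 3 real tokens, hidden size 4 (toy values)
padding = torch.randn(1, 2, 4) * 100   # arbitrary, large values at 2 padded positions
padded = torch.cat([hidden, padding], dim=1)

a = mean_pooling(hidden, torch.ones(1, 3, dtype=torch.long))
b = mean_pooling(padded, torch.tensor([[1, 1, 1, 0, 0]]))
print(torch.allclose(a, b, atol=1e-6))                     # True
print(torch.allclose(a, hidden.mean(dim=1), atol=1e-6))    # True
```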

## Training Details

RoSEtta is trained in the following steps:

**Step 1: Pre-training**

- The model, based on the RoFormer architecture, is pre-trained with an MLM objective.
- Training data: [Japanese Wikipedia](https://dumps.wikimedia.org/other/cirrussearch/) and [CC-100](https://data.statmt.org/cc-100/).
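
This card does not state RoSEtta's masking configuration. The sketch below therefore assumes the standard BERT recipe: about 15% of non-special tokens are selected as targets, and of those, 80% become the mask token, 10% become a random token, and 10% stay unchanged. The function `mask_tokens` and the toy token IDs are hypothetical. Labels use -100 because that is the default `ignore_index` of PyTorch's cross-entropy loss.

```python
import torch

def mask_tokens(input_ids: torch.Tensor, special_mask: torch.Tensor,
                mask_token_id: int, vocab_size: int, mlm_prob: float = 0.15):
    """Hypothetical BERT-style masking. Returns (masked_inputs, labels); labels are -100 where no loss is taken."""
    labels = input_ids.clone()
    inputs = input_ids.clone()

    # Select ~mlm_prob of the non-special tokens as prediction targets
    probs = torch.full(input_ids.shape, mlm_prob)
    probs.masked_fill_(special_mask, 0.0)
    selected = torch.bernoulli(probs).bool()
    labels[~selected] = -100

    # 80% of targets -> mask token
    to_mask = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & selected
    inputs[to_mask] = mask_token_id

    # Half of the remaining 20% (i.e. 10%) -> random token
    to_random = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & selected & ~to_mask
    random_ids = torch.randint(vocab_size, input_ids.shape, dtype=input_ids.dtype)
    inputs[to_random] = random_ids[to_random]

    # The last 10% of targets keep their original token
    return inputs, labels

torch.manual_seed(0)
ids = torch.randint(5, 100, (2, 16))       # toy token ids; 0-4 treated as reserved here
special = torch.zeros_like(ids, dtype=torch.bool)
special[:, 0] = True                       # e.g. a [CLS]-like first position that is never masked
masked, labels = mask_tokens(ids, special, mask_token_id=4, vocab_size=100)
```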

**Step 2: Weakly supervised learning**

- Training data: [MQA](https://huggingface.co/datasets/clips/mqa) and [mc4](https://huggingface.co/datasets/legacy-datasets/mc4).
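
The card does not specify the loss used in this step. Weakly supervised training of embedding models commonly uses an InfoNCE loss with in-batch negatives, where each text pair, such as a question and its answer, is a positive and every other passage in the batch acts as a negative. The sketch below shows that loss under this assumption. The function name and the temperature of 0.05 are illustrative choices, not values from this card.

```python
import torch
import torch.nn.functional as F

def in_batch_info_nce(query_emb: torch.Tensor, passage_emb: torch.Tensor,
                      temperature: float = 0.05) -> torch.Tensor:
    """Hypothetical helper: passage i is the positive for query i; all other passages are negatives."""
    q = F.normalize(query_emb, dim=-1)
    p = F.normalize(passage_emb, dim=-1)
    logits = q @ p.T / temperature                       # [batch, batch] scaled cosine similarities
    targets = torch.arange(q.size(0), device=q.device)   # the correct passage lies on the diagonal
    return F.cross_entropy(logits, targets)

torch.manual_seed(0)
q = torch.randn(8, 768)
aligned = in_batch_info_nce(q, q + 0.01 * torch.randn(8, 768))  # positives nearly identical to queries
unrelated = in_batch_info_nce(q, torch.randn(8, 768))           # "positives" unrelated to queries
print(aligned.item() < unrelated.item())  # True
```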

**Step 3: Ensemble distillation**

- Embeddings were distilled from three teacher models: [E5-mistral](https://huggingface.co/intfloat/e5-mistral-7b-instruct), [gte-Qwen2](https://huggingface.co/Alibaba-NLP/gte-Qwen2-7B-instruct), and [mE5-large](https://huggingface.co/intfloat/multilingual-e5-large).
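
The card does not describe the distillation objective. The teachers and the 768-dimensional student produce embeddings of different sizes, so one possible approach, shown here purely as an illustration, is to match similarity matrices rather than raw vectors. In this approach, the student's in-batch cosine-similarity matrix is regressed onto the average of the teachers' matrices. Projecting the student into each teacher's space with learned heads would be an alternative. All names and the toy teacher dimensions below are assumptions.

```python
import torch
import torch.nn.functional as F

def pairwise_cos(x: torch.Tensor) -> torch.Tensor:
    """In-batch cosine-similarity matrix, [batch, batch]."""
    x = F.normalize(x, dim=-1)
    return x @ x.T

def ensemble_similarity_distillation(student_emb: torch.Tensor, teacher_embs: list) -> torch.Tensor:
    """Hypothetical loss: MSE between the student's similarity matrix and the teachers' mean matrix.

    Comparing [batch, batch] matrices sidesteps the differing embedding sizes.
    """
    target = torch.stack([pairwise_cos(t) for t in teacher_embs]).mean(dim=0)
    return F.mse_loss(pairwise_cos(student_emb), target)

torch.manual_seed(0)
batch = 6
teachers = [torch.randn(batch, 1024), torch.randn(batch, 2048)]  # toy teacher outputs of different sizes
student = torch.randn(batch, 768, requires_grad=True)           # toy student output

loss = ensemble_similarity_distillation(student, teachers)
loss.backward()
print(loss.item() > 0, student.grad.shape)  # True torch.Size([6, 768])
```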

**Step 4: Contrastive learning**

- Triplets were built from [JSNLI](https://nlp.ist.i.kyoto-u.ac.jp/?%E6%97%A5%E6%9C%AC%E8%AA%9ESNLI%28JSNLI%29%E3%83%87%E3%83%BC%E3%82%BF%E3%82%BB%E3%83%83%E3%83%88), [MNLI](https://huggingface.co/datasets/MoritzLaurer/multilingual-NLI-26lang-2mil7), [PAWS-X](https://huggingface.co/datasets/paws-x), [JSeM](https://github.com/DaisukeBekki/JSeM), and [Mr.TyDi](https://huggingface.co/datasets/castorini/mr-tydi) and used for training.
- This stage aimed to improve the model's overall performance as a sentence embedding model.
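
The card does not say how the triplets were formed or which loss was applied. A widely used recipe for (anchor, positive, negative) triplets, for example an NLI premise with an entailed and a contradicting hypothesis, adds the hard negatives to the in-batch candidates of an InfoNCE loss. The sketch below follows that assumption. The function name and temperature are illustrative, and the toy vectors show that a hard negative close to the anchor raises the loss more than an unrelated one does.

```python
import torch
import torch.nn.functional as F

def info_nce_with_hard_negatives(anchor: torch.Tensor, positive: torch.Tensor,
                                 negative: torch.Tensor, temperature: float = 0.05) -> torch.Tensor:
    """Hypothetical helper: each anchor must pick its positive among all positives and hard negatives."""
    a = F.normalize(anchor, dim=-1)
    candidates = F.normalize(torch.cat([positive, negative], dim=0), dim=-1)  # [2B, d]
    logits = a @ candidates.T / temperature                                   # [B, 2B]
    targets = torch.arange(a.size(0), device=a.device)  # positive i sits in column i
    return F.cross_entropy(logits, targets)

torch.manual_seed(0)
anchor = torch.randn(4, 768)
positive = anchor + 0.1 * torch.randn(4, 768)   # toy positives, very close to the anchors
easy_neg = torch.randn(4, 768)                  # unrelated negatives
hard_neg = anchor + 0.3 * torch.randn(4, 768)   # negatives that are still close to the anchors

easy = info_nce_with_hard_negatives(anchor, positive, easy_neg).item()
hard = info_nce_with_hard_negatives(anchor, positive, hard_neg).item()
print(easy < hard)  # True
```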

**Step 5: Search-specific contrastive learning**

- To make the model more robust on retrieval, two additional stages of training on QA and retrieval tasks were conducted.
- In the first stage, the model was trained on the synthetic dataset [auto-wiki-qa](https://huggingface.co/datasets/cl-nagoya/auto-wiki-qa).
- In the second stage, [JQaRA](https://huggingface.co/datasets/hotchpotch/JQaRA), [MQA](https://huggingface.co/datasets/hpprc/mqa-ja), and [Japanese Wikipedia Human Retrieval, Mr.TyDi, MIRACL, Quiz Works, and Quiz No Mori](https://huggingface.co/datasets/hpprc/emb) were used.

## Benchmarks

### Retrieval

Evaluated on [MIRACL-ja](https://huggingface.co/datasets/miracl/miracl), [JQaRA](https://huggingface.co/datasets/hotchpotch/JQaRA), [JaCWIR](https://huggingface.co/datasets/hotchpotch/JaCWIR), and [MLDR-ja](https://huggingface.co/datasets/Shitao/MLDR).

| Model | Size | MIRACL<br>Recall@5 | JQaRA<br>nDCG@10 | JaCWIR<br>MAP@10 | MLDR<br>nDCG@10 |
| :---: | :---: | :---: | :---: | :---: | :---: |
| [intfloat/multilingual-e5-large](https://huggingface.co/intfloat/multilingual-e5-large) | 0.6B | 89.2 | 55.4 | **87.6** | 29.8 |
| [cl-nagoya/ruri-large](https://huggingface.co/cl-nagoya/ruri-large) | 0.3B | 78.7 | 62.4 | 85.0 | **37.5** |
| | | | | | |
| [intfloat/multilingual-e5-base](https://huggingface.co/intfloat/multilingual-e5-base) | 0.3B | **84.2** | 47.2 | **85.3** | 25.4 |
| [cl-nagoya/ruri-base](https://huggingface.co/cl-nagoya/ruri-base) | 0.1B | 74.3 | **58.1** | 84.6 | **35.3** |
| [pkshatech/GLuCoSE-base-ja](https://huggingface.co/pkshatech/GLuCoSE-base-ja) | 0.1B | 53.3 | 30.8 | 68.6 | 25.2 |
| RoSEtta | 0.2B | 79.3 | 57.7 | 83.8 | 32.3 |

Note: Results for OpenAI small embeddings on JQaRA and JaCWIR are quoted from the [JQaRA](https://huggingface.co/datasets/hotchpotch/JQaRA) and [JaCWIR](https://huggingface.co/datasets/hotchpotch/JaCWIR) dataset pages.
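
The column headers in the retrieval table refer to standard ranking metrics. The sketch below implements common binary-relevance definitions of Recall@k, nDCG@k, and AP@k, whose mean over queries is MAP@k. The official evaluation scripts for each benchmark may differ in details such as graded relevance or the AP normalization, so treat these as illustrative. The function names and the example ranking are hypothetical.

```python
import math
from typing import Sequence, Set

def recall_at_k(ranked: Sequence[str], relevant: Set[str], k: int) -> float:
    """Fraction of the relevant documents that appear in the top k."""
    if not relevant:
        return 0.0
    return len(set(ranked[:k]) & relevant) / len(relevant)

def ndcg_at_k(ranked: Sequence[str], relevant: Set[str], k: int) -> float:
    """nDCG@k with binary relevance (gain 1 for relevant documents, 0 otherwise)."""
    dcg = sum(1.0 / math.log2(rank + 2) for rank, doc in enumerate(ranked[:k]) if doc in relevant)
    ideal = sum(1.0 / math.log2(rank + 2) for rank in range(min(len(relevant), k)))
    return dcg / ideal if ideal > 0 else 0.0

def average_precision_at_k(ranked: Sequence[str], relevant: Set[str], k: int) -> float:
    """AP@k: sum of precision@i at each relevant rank i <= k, divided by min(|relevant|, k)."""
    hits, precision_sum = 0, 0.0
    for i, doc in enumerate(ranked[:k], start=1):
        if doc in relevant:
            hits += 1
            precision_sum += hits / i
    denom = min(len(relevant), k)
    return precision_sum / denom if denom else 0.0

ranked = ["d3", "d1", "d7", "d2"]   # toy ranking returned for one query
relevant = {"d1", "d2"}             # toy ground-truth relevant documents
print(recall_at_k(ranked, relevant, k=2))             # 0.5
print(round(ndcg_at_k(ranked, relevant, k=4), 4))     # 0.6509
print(average_precision_at_k(ranked, relevant, k=4))  # 0.5
```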

### JMTEB

Evaluated with [JMTEB](https://github.com/sbintuitions/JMTEB).

The average score is a macro-average.

| Model | Size | Avg. | Retrieval | STS | Classification | Reranking | Clustering | PairClassification |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| OpenAI/text-embedding-3-small | - | 69.18 | 66.39 | 79.46 | 73.06 | 92.92 | 51.06 | 62.27 |
| OpenAI/text-embedding-3-large | - | 74.05 | 74.48 | 82.52 | 77.58 | 93.58 | 53.32 | 62.35 |
| | | | | | | | | |
| [intfloat/multilingual-e5-large](https://huggingface.co/intfloat/multilingual-e5-large) | 0.6B | 70.90 | 70.98 | 79.70 | 72.89 | 92.96 | 51.24 | 62.15 |
| [cl-nagoya/ruri-large](https://huggingface.co/cl-nagoya/ruri-large) | 0.3B | 73.31 | 73.02 | 83.13 | 77.43 | 92.99 | 51.82 | 62.29 |
| | | | | | | | | |
| [intfloat/multilingual-e5-base](https://huggingface.co/intfloat/multilingual-e5-base) | 0.3B | 68.61 | 68.21 | 79.84 | 69.30 | 92.85 | 48.26 | 62.26 |
| [cl-nagoya/ruri-base](https://huggingface.co/cl-nagoya/ruri-base) | 0.1B | 71.91 | 69.82 | **82.87** | 75.58 | **92.91** | **54.16** | 62.38 |
| [pkshatech/GLuCoSE-base-ja](https://huggingface.co/pkshatech/GLuCoSE-base-ja) | 0.1B | 67.29 | 59.02 | 78.71 | **76.82** | 91.90 | 49.78 | **66.39** |
| RoSEtta | 0.2B | **72.45** | **73.21** | 81.39 | 72.41 | 92.69 | 53.23 | 61.74 |

## Authors

Chihiro Yano, Mocho Go, Hideyuki Tachibana, Hiroto Takegawa, Yotaro Watanabe

## License

This model is published under the [Apache License, Version 2.0](https://www.apache.org/licenses/LICENSE-2.0).