高性能な日本語 SPLADE (Sparse Lexical and Expansion Model) モデルです。テキストからスパースベクトルへの変換デモで、どのようなスパースベクトルに変換できるか、WebUI から気軽にお試しいただけます。
また、モデルの学習にはYAST - Yet Another SPLADE or Sparse Trainerを使っています。
利用方法
YASEM (Yet Another Splade|Sparse Embedder)
pip install yasem
from yasem import SpladeEmbedder
model_name = "hotchpotch/japanese-splade-base-v1"
embedder = SpladeEmbedder(model_name)
sentences = [
"車の燃費を向上させる方法は?",
"急発進や急ブレーキを避け、一定速度で走行することで燃費が向上します。",
"車を長持ちさせるには、消耗品を適切なタイミングで交換することが重要です。",
]
embeddings = embedder.encode(sentences)
similarity = embedder.similarity(embeddings, embeddings)
print(similarity)
# [[21.49299249 10.48868281 6.25582337]
# [10.48868281 12.90587398 3.19429791]
# [ 6.25582337 3.19429791 12.89678271]]
token_values = embedder.get_token_values(embeddings[0])
print(token_values)
#{
# '車': 2.1796875,
# '燃費': 2.146484375,
# '向上': 1.7353515625,
# '方法': 1.55859375,
# '燃料': 1.3291015625,
# '効果': 1.1376953125,
# '良い': 0.873046875,
# '改善': 0.8466796875,
# 'アップ': 0.833984375,
# 'いう': 0.70849609375,
# '理由': 0.64453125,
# ...
transformers
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch
model = AutoModelForMaskedLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
def splade_max_pooling(logits, attention_mask):
relu_log = torch.log(1 + torch.relu(logits))
weighted_log = relu_log * attention_mask.unsqueeze(-1)
max_val, _ = torch.max(weighted_log, dim=1)
return max_val
tokens = tokenizer(
sentences, return_tensors="pt", padding=True, truncation=True, max_length=512
)
tokens = {k: v.to(model.device) for k, v in tokens.items()}
with torch.no_grad():
outputs = model(**tokens)
embeddings = splade_max_pooling(outputs.logits, tokens["attention_mask"])
similarity = torch.matmul(embeddings.unsqueeze(0), embeddings.T).squeeze(0)
print(similarity)
# tensor([[21.4943, 10.4816, 6.2540],
# [10.4816, 12.9024, 3.1939],
# [ 6.2540, 3.1939, 12.8919]])
ベンチマークスコア
retrieval (JMTEB)
JMTEB の評価結果です。japanese-splade-base-v1 は JMTEB をスパースベクトルで評価できるように変更したコードでの評価となっています。 なお、japanese-splade-base-v1 は jaqket, mrtydi のドメインを学習(testのデータ以外)しています。
model_name | Avg. | jagovfaqs_22k | jaqket | mrtydi | nlp_journal_abs_intro | nlp_journal_title_abs | nlp_journal_title_intro |
---|---|---|---|---|---|---|---|
japanese-splade-base-v1 | 0.7465 | 0.6499 | 0.6992 | 0.4365 | 0.8967 | 0.9766 | 0.8203 |
text-embedding-3-large | 0.7448 | 0.7241 | 0.4821 | 0.3488 | 0.9933 | 0.9655 | 0.9547 |
GLuCoSE-base-ja-v2 | 0.7336 | 0.6979 | 0.6729 | 0.4186 | 0.9029 | 0.9511 | 0.7580 |
multilingual-e5-large | 0.7098 | 0.7030 | 0.5878 | 0.4363 | 0.8600 | 0.9470 | 0.7248 |
multilingual-e5-small | 0.6727 | 0.6411 | 0.4997 | 0.3605 | 0.8521 | 0.9526 | 0.7299 |
ruri-large | 0.7302 | 0.7668 | 0.6174 | 0.3803 | 0.8712 | 0.9658 | 0.7797 |
reranking
JaCWIR
なお、japanese-splade-base-v1 は JaCWIR のドメインを学習していません。
model_names | map@10 | hit_rate@10 |
---|---|---|
japanese-splade-base-v1 | 0.9122 | 0.9854 |
text-embedding-3-small | 0.8168 | 0.9506 |
GLuCoSE-base-ja-v2 | 0.8567 | 0.9676 |
bge-m3+dense | 0.8642 | 0.9684 |
multilingual-e5-large | 0.8759 | 0.9726 |
multilingual-e5-small | 0.869 | 0.97 |
ruri-large | 0.8291 | 0.9594 |
JQaRA
なお、japanese-splade-base-v1 は JQaRA のドメイン(test以外)を学習したものとなっています。
model_names | ndcg@10 | mrr@10 |
---|---|---|
japanese-splade-base-v1 | 0.6441 | 0.8616 |
text-embedding-3-small | 0.3881 | 0.6107 |
bge-m3+dense | 0.539 | 0.7854 |
multilingual-e5-large | 0.554 | 0.7988 |
multilingual-e5-small | 0.4917 | 0.7291 |
GLuCoSE-base-ja-v2 | 0.606 | 0.8359 |
ruri-large | 0.6287 | 0.8418 |
学習元データセット
hpprc/emb から、auto-wiki-qa, mmarco, jsquad jaquad, auto-wiki-qa-nemotron, quiz-works quiz-no-mori, miracl, jqara mr-tydi, baobab-wiki-retrieval, mkqa データセットを利用しています。 また英語データセットとして、MS Marcoを利用しています。
注意事項
text-embeddings-inference で動かす場合、hotchpotch/japanese-splade-base-v1-dummy-fast-tokenizer-for-teiをご利用ください。
- Downloads last month
- 840
Model tree for hotchpotch/japanese-splade-base-v1
Base model
tohoku-nlp/bert-base-japanese-v3