---
license: cc-by-sa-4.0
datasets:
- unicamp-dl/mmarco
language:
- ja
---
|
|
|
We initialize SPLADE-japanese from [tohoku-nlp/bert-base-japanese-v2](https://huggingface.co/tohoku-nlp/bert-base-japanese-v2). |
|
This model is trained on the Japanese subset of the [mMARCO](https://github.com/unicamp-dl/mMARCO) dataset.
|
|
|
```python
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch
import numpy as np

model = AutoModelForMaskedLM.from_pretrained("aken12/splade-japanese")
tokenizer = AutoTokenizer.from_pretrained("aken12/splade-japanese")

query = "筑波大学では何の研究が行われているか?"  # "What research is conducted at the University of Tsukuba?"

def encode_query(query, tokenizer, model):
    encoded_input = tokenizer(query, return_tensors="pt")
    with torch.no_grad():
        output = model(**encoded_input, return_dict=True).logits
    # SPLADE pooling: log-saturated ReLU over the MLM logits, masked by the
    # attention mask, then max-pooled over the sequence dimension.
    aggregated_output, _ = torch.max(
        torch.log(1 + torch.relu(output)) * encoded_input['attention_mask'].unsqueeze(-1),
        dim=1,
    )
    return aggregated_output

def get_topk_tokens(reps, vocab_dict, topk):
    # Keep the top-k vocabulary dimensions, map the ids back to tokens, and
    # scale the weights by 100 (rounded to integers) for readability.
    topk_values, topk_indices = torch.topk(reps, topk, dim=1)
    values = np.rint(topk_values.numpy() * 100).astype(int)
    return {
        vocab_dict[id_token.item()]: int(value_token)
        for id_token, value_token in zip(topk_indices[0], values[0])
        if value_token > 0
    }

vocab_dict = {v: k for k, v in tokenizer.get_vocab().items()}
topk = len(vocab_dict) // 1000  # roughly 0.1% of the vocabulary

model_output = encode_query(query, tokenizer, model)
dict_splade = get_topk_tokens(model_output, vocab_dict, topk)

for token, value in dict_splade.items():
    print(token, value)
```
|
|
|
## Output
|
```
に 250
が 248
は 247
の 247
、 244
と 240
を 239
。 239
も 238
で 237
から 221
や 219
な 206
筑波 204
( 204
・ 202
て 197
へ 191
にて 189
など 188
) 186
まで 184
た 182
この 171
- 170
「 170
より 166
その 165
: 163
」 161
```
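
## Retrieval example (sketch)

As a rough illustration of how these sparse vectors can be used for retrieval, the sketch below reuses `encode_query` from above to encode a couple of passages and ranks them by the inner product with the query representation. The example passages are arbitrary and only for illustration; they are not part of the training data or the evaluation of this model.

```python
# Minimal retrieval sketch: score hypothetical passages against the query
# by the dot product of their SPLADE vectors.
docs = [
    "筑波大学では情報科学や生命科学などの研究が行われている。",
    "富士山は日本で最も高い山である。",
]

query_rep = encode_query(query, tokenizer, model)                          # shape: (1, vocab_size)
doc_reps = torch.cat([encode_query(d, tokenizer, model) for d in docs])    # shape: (num_docs, vocab_size)

# SPLADE relevance score: inner product between query and document vectors.
scores = query_rep @ doc_reps.T
for doc, score in zip(docs, scores[0].tolist()):
    print(f"{score:.2f}\t{doc}")
```

In practice the sparse vectors would be stored in an inverted index rather than scored densely as above; this snippet only shows that the representations behave as expected.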
|
|