aken12 committed
Commit 2de94d1
1 Parent(s): 934c239

Update README.md

Files changed (1): README.md +45 -0
---
license: cc-by-sa-4.0
datasets:
- unicamp-dl/mmarco
language:
- ja
---

We initialize SPLADE-japanese from [tohoku-nlp/bert-base-japanese-v2](https://huggingface.co/tohoku-nlp/bert-base-japanese-v2) and train it on the Japanese portion of the [mMARCO](https://github.com/unicamp-dl/mMARCO) dataset.

The example below encodes a query into a sparse vector over the vocabulary and prints the highest-weighted expansion tokens:

```python
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch
import numpy as np

model = AutoModelForMaskedLM.from_pretrained("aken12/splade-japanese")
tokenizer = AutoTokenizer.from_pretrained("aken12/splade-japanese")

query = "私は筑波大学の学生です"  # "I am a student at the University of Tsukuba"

def encode_query(query, tokenizer, model):
    encoded_input = tokenizer(query, return_tensors="pt")
    with torch.no_grad():
        output = model(**encoded_input, return_dict=True).logits
    # SPLADE activation: log(1 + ReLU(logits)), masked by attention,
    # then max-pooled over the token positions
    aggregated_output, _ = torch.max(
        torch.log(1 + torch.relu(output)) * encoded_input["attention_mask"].unsqueeze(-1),
        dim=1,
    )
    return aggregated_output

def get_topk_tokens(reps, vocab_dict, topk):
    # Keep the top-k vocabulary dimensions and map ids back to token strings
    topk_values, topk_indices = torch.topk(reps, topk, dim=1)
    values = np.rint(topk_values.numpy() * 100).astype(int)
    dict_splade = {
        vocab_dict[id_token.item()]: int(value_token)
        for id_token, value_token in zip(topk_indices[0], values[0])
        if value_token > 0
    }
    return dict_splade

vocab_dict = {v: k for k, v in tokenizer.get_vocab().items()}
topk = len(vocab_dict) // 1000  # inspect roughly 0.1% of the vocabulary

model_output = encode_query(query, tokenizer, model)
dict_splade = get_topk_tokens(model_output, vocab_dict, topk)

for token, value in dict_splade.items():
    print(token, value)
```
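
The snippet above stops at query encoding. As a minimal retrieval sketch (not part of the original card), the same encoder can be reused on documents, with relevance scored as the dot product of the query and document sparse vectors. The example documents here are made up for illustration, and note that real SPLADE setups sometimes use separate query and document encoders.

```python
# Hypothetical retrieval sketch: reuses encode_query from the example above
docs = [
    "筑波大学は茨城県つくば市にある国立大学です",  # "The University of Tsukuba is a national university in Tsukuba, Ibaraki"
    "今日の天気は晴れです",                        # "The weather today is sunny"
]

query_vec = encode_query(query, tokenizer, model)  # shape: (1, vocab_size)
doc_vecs = torch.cat([encode_query(d, tokenizer, model) for d in docs])  # (n_docs, vocab_size)

# Sparse dot-product relevance scores; higher means more relevant
scores = query_vec @ doc_vecs.T  # shape: (1, n_docs)
for doc, score in zip(docs, scores[0]):
    print(f"{score.item():.2f}\t{doc}")
```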