---
license: cc-by-sa-4.0
datasets:
  - unicamp-dl/mmarco
language:
  - ja
---

We initialize SPLADE-japanese from tohoku-nlp/bert-base-japanese-v2 and train it on the Japanese subset of the mMARCO dataset.
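
To run the example below you need `transformers` and `torch`. Because the tokenizer comes from tohoku-nlp/bert-base-japanese-v2, `fugashi` and `unidic-lite` are most likely also required (`pip install fugashi unidic-lite`).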

```python
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch
import numpy as np

model = AutoModelForMaskedLM.from_pretrained("aken12/splade-japanese")
tokenizer = AutoTokenizer.from_pretrained("aken12/splade-japanese")

query = "私は筑波大学の学生です"  # "I am a student at the University of Tsukuba"

def encode_query(query, tokenizer, model):
    """Encode a query into a sparse, vocabulary-sized vector (SPLADE max pooling)."""
    encoded_input = tokenizer(query, return_tensors="pt")
    with torch.no_grad():
        output = model(**encoded_input, return_dict=True).logits
    # log(1 + ReLU(logits)), masked by attention, then max-pooled over the token positions
    aggregated_output, _ = torch.max(torch.log(1 + torch.relu(output)) * encoded_input['attention_mask'].unsqueeze(-1), dim=1)
    return aggregated_output

def get_topk_tokens(reps, vocab_dict, topk):
    """Return the top-k vocabulary tokens with their weights, scaled by 100 and rounded."""
    topk_values, topk_indices = torch.topk(reps, topk, dim=1)
    values = np.rint(topk_values.numpy() * 100).astype(int)
    dict_splade = {vocab_dict[id_token.item()]: int(value_token) for id_token, value_token in zip(topk_indices[0], values[0]) if value_token > 0}
    return dict_splade

# Map token ids back to token strings
vocab_dict = {v: k for k, v in tokenizer.get_vocab().items()}
topk = len(vocab_dict) // 1000  # keep roughly the top 0.1% of the vocabulary

model_output = encode_query(query, tokenizer, model)

dict_splade = get_topk_tokens(model_output, vocab_dict, topk)

for token, value in dict_splade.items():
    print(token, value)
```
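
For retrieval, the relevance of a document to a query can be computed as the dot product of their sparse representations. Below is a minimal sketch assuming the same `encode_query` function is reused for the document side; the document string is only an illustrative example and is not part of the model card.

```python
# Sketch: score a document against the query via the sparse dot product.
doc = "筑波大学は茨城県つくば市にある国立大学です"  # "The University of Tsukuba is a national university in Tsukuba, Ibaraki" (illustrative)

query_rep = encode_query(query, tokenizer, model)  # shape: (1, vocab_size)
doc_rep = encode_query(doc, tokenizer, model)      # same encoder applied to the document

score = torch.matmul(query_rep, doc_rep.T).item()  # dot product of the two sparse vectors
print(score)
```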