sonoisa committed
Commit c5fb5a9 · 1 Parent(s): c840e13

Create README.md

Files changed (1)
  1. README.md +74 -0
README.md ADDED
@@ -0,0 +1,74 @@
+ ---
+ language: ja
+ license: cc-by-sa-4.0
+ tags:
+ - sentence-transformers
+ - sentence-bert
+ - feature-extraction
+ - sentence-similarity
+ ---
+
+ This is a Japanese+English Sentence-BERT model.
+
+ It uses [cl-tohoku/bert-base-japanese-whole-word-masking](https://huggingface.co/cl-tohoku/bert-base-japanese-whole-word-masking) as the pretrained model.
+ Running inference requires fugashi and ipadic (`pip install fugashi ipadic`).
+
+
+ # Explanation of the Japanese-only version
+
+ https://qiita.com/sonoisa/items/1df94d0a98cd4f209051
+
+ Replace the model name with "sonoisa/sentence-bert-base-ja-en-mean-tokens" to get the behavior of this model.
+
+
+ # Usage
+
+ ```python
+ from transformers import BertJapaneseTokenizer, BertModel
+ import torch
+
+
+ class SentenceBertJapanese:
+     def __init__(self, model_name_or_path, device=None):
+         self.tokenizer = BertJapaneseTokenizer.from_pretrained(model_name_or_path)
+         self.model = BertModel.from_pretrained(model_name_or_path)
+         self.model.eval()
+
+         # Default to the GPU when one is available.
+         if device is None:
+             device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.device = torch.device(device)
+         self.model.to(device)
+
+     def _mean_pooling(self, model_output, attention_mask):
+         # First element of model_output contains all token embeddings.
+         token_embeddings = model_output[0]
+         input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+         # Average over non-padding tokens only; the clamp avoids division by zero.
+         return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+     @torch.no_grad()
+     def encode(self, sentences, batch_size=8):
+         all_embeddings = []
+         iterator = range(0, len(sentences), batch_size)
+         for batch_idx in iterator:
+             batch = sentences[batch_idx:batch_idx + batch_size]
+
+             encoded_input = self.tokenizer.batch_encode_plus(batch, padding="longest",
+                                                              truncation=True, return_tensors="pt").to(self.device)
+             model_output = self.model(**encoded_input)
+             sentence_embeddings = self._mean_pooling(model_output, encoded_input["attention_mask"]).to('cpu')
+
+             all_embeddings.extend(sentence_embeddings)
+
+         # return torch.stack(all_embeddings).numpy()
+         return torch.stack(all_embeddings)
+
+
+ MODEL_NAME = "sonoisa/sentence-bert-base-ja-en-mean-tokens"
+ model = SentenceBertJapanese(MODEL_NAME)
+
+ sentences = ["暴走したAI", "暴走した人工知能"]
+ sentence_embeddings = model.encode(sentences, batch_size=8)
+
+ print("Sentence embeddings:", sentence_embeddings)
+ ```
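
The `_mean_pooling` step in the README averages token embeddings while ignoring padded positions, and the resulting sentence embeddings are typically compared with cosine similarity. The sketch below reproduces that arithmetic with plain Python lists so it runs without downloading the model. The helper names `mean_pool` and `cosine_similarity` and all the numbers are made up for illustration; they are not part of the README or of any library.

```python
import math


def mean_pool(token_embeddings, attention_mask):
    """Average the token vectors whose mask value is 1; padding (mask 0) is ignored."""
    dim = len(token_embeddings[0])
    summed = [0.0] * dim
    for vector, keep in zip(token_embeddings, attention_mask):
        for i in range(dim):
            summed[i] += vector[i] * keep
    # Same guard as torch.clamp(..., min=1e-9) in the README:
    # avoids division by zero when every position is masked out.
    count = max(sum(attention_mask), 1e-9)
    return [value / count for value in summed]


def cosine_similarity(a, b):
    """Cosine of the angle between two vectors of equal length."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


# Hypothetical 2-dimensional token embeddings.
# The last position of the first sentence is padding and must not affect the mean.
sentence_a = mean_pool([[1.0, 0.0], [3.0, 2.0], [9.0, 9.0]], [1, 1, 0])
sentence_b = mean_pool([[0.0, 2.0], [2.0, 0.0]], [1, 1])

print(sentence_a)  # [2.0, 1.0]
print(sentence_b)  # [1.0, 1.0]
print(cosine_similarity(sentence_a, sentence_b))
```

With real model output, the same comparison could be applied to the rows of `sentence_embeddings` returned by `encode`, for example to check how close "暴走したAI" and "暴走した人工知能" are.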