Katsumata420
committed on
Commit
·
bff929a
1
Parent(s):
825d6e6
Upload scripts
Browse files- download_wikipedia_bert.py +17 -0
- sample_mlm.py +21 -0
download_wikipedia_bert.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Fetch the Tohoku BERT-Japanese tokenizer and config from the Hub and
replicate them into each local fine-tuned model directory, so those
directories are self-contained (loadable via `from_pretrained`).

Requires network access on first run; `transformers` caches afterwards.
"""
from transformers import BertJapaneseTokenizer
from transformers import BertConfig
from transformers import BertForPreTraining  # kept for the example below

# Base checkpoint whose tokenizer/config are shared by all local models.
BASE_MODEL = 'cl-tohoku/bert-base-japanese-whole-word-masking'

# Local model directories that should each receive a tokenizer + config copy.
TARGET_DIRS = (
    'models/1-6_layer-wise',
    'models/tapt512_60K',
    'models/dapt128-tapt512',
)

tokenizer = BertJapaneseTokenizer.from_pretrained(BASE_MODEL)
# `from_pretrained` is a classmethod: call it on the class directly instead of
# the original `BertConfig().from_pretrained(...)`, which built and discarded
# a default config instance for no reason.
config = BertConfig.from_pretrained(BASE_MODEL)

for target_dir in TARGET_DIRS:
    tokenizer.save_pretrained(target_dir)
    config.save_pretrained(target_dir)

# Example of loading one of the populated directories as a pre-training model:
# model = BertForPreTraining(config).from_pretrained('models/dapt128-tapt512')
|
sample_mlm.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Fill-mask demo: pair each masked Japanese sentence with one local BERT
checkpoint and print the model's top predictions for the [MASK] token."""
from transformers import BertJapaneseTokenizer
from transformers import BertConfig
from transformers import BertForMaskedLM
from transformers import pipeline

inputs = ['[MASK]もそう思います', '[MASK]なんというかその', 'これは[MASK]私が子供の頃の話なんですけど']

model_name_list = ['models/1-6_layer-wise', 'models/tapt512_60K', 'models/dapt128-tapt512']


for input_, model_name in zip(inputs, model_name_list):

    # Load every component of the checkpoint from its local directory.
    loaded_tokenizer = BertJapaneseTokenizer.from_pretrained(model_name)
    loaded_config = BertConfig.from_pretrained(model_name)
    loaded_model = BertForMaskedLM.from_pretrained(model_name)

    # Build the fill-mask pipeline up front, then report model, input, output.
    fill_mask = pipeline(
        'fill-mask',
        model=loaded_model,
        tokenizer=loaded_tokenizer,
        config=loaded_config,
    )
    print('model name:',model_name)
    print('input:',input_)
    print('output:',fill_mask(input_))
    print()
|