File size: 3,063 Bytes
ece09c3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
from typing import Dict, List
from transformers import BertJapaneseTokenizer, BertForTokenClassification, pipeline
from transformers.pipelines.token_classification import TokenClassificationPipeline
class NaturalLanguageProcessing:
"""
固有表現を抽出するクラス
model_dirにある言語モデルを使って固有表現抽出パイプラインを作成する
モデルに固有表現を抽出させるメソッドを持つ
Attributes
----------
_nlp : TokenClassificationPipeline
固有表現抽出パイプライン
"""
def __init__(self, model_dir: str):
"""
コンストラクタ
_nlpを作成する
Parameters
----------
model_dir : str
使用する言語モデルのディレクトリ
"""
self._nlp = NaturalLanguageProcessing._create(model_dir)
@staticmethod
def _create(model_dir: str) -> TokenClassificationPipeline:
"""
パイプラインの作成
Parameters
----------
model_dir : str
使用する言語モデルのディレクトリ
Returns
-------
TokenClassificationPipeline
固有表現抽出パイプライン
"""
tokenizer = BertJapaneseTokenizer.from_pretrained(model_dir)
model = BertForTokenClassification.from_pretrained(model_dir)
nlp = pipeline(
'token-classification',
model=model,
tokenizer=tokenizer,
aggregation_strategy='simple'
)
return nlp
def classify(self, input: str) -> Dict[str, List[str]]:
"""
固有表現の抽出
Parameters
----------
input : str
固有表現抽出対象
Returns
-------
Dict[str, List[str]]
抽出結果の辞書
キーが分類ラベル、バリューがそのラベルの文字列のリスト
"""
prediction_results:List[Dict[str, str | float | None]] = self._nlp(input)
classified_words = {}
for predict_result in prediction_results:
label = predict_result['entity_group']
word = predict_result['word']
if label not in classified_words:
classified_words[label] = []
classified_words[label].append(word.replace(' ', ''))
return classified_words
def classify_and_show(self, input: str) -> Dict[str, List[str]]:
"""
固有表現の抽出と表示
Parameters
----------
input : str
固有表現抽出対象
Returns
-------
Dict[str, List[str]]
抽出結果の辞書
キーが分類ラベル、バリューがそのラベルの文字列のリスト
"""
classified_words = self.classify(input)
for label, words in classified_words.items():
print(f'{label: <10} {"、".join(words)}')
return classified_words
|