wolf4032's picture
Upload 6 files
ece09c3 verified
raw
history blame
3.06 kB
from typing import Dict, List
from transformers import BertJapaneseTokenizer, BertForTokenClassification, pipeline
from transformers.pipelines.token_classification import TokenClassificationPipeline
class NaturalLanguageProcessing:
"""
固有表現を抽出するクラス
model_dirにある言語モデルを使って固有表現抽出パイプラインを作成する
モデルに固有表現を抽出させるメソッドを持つ
Attributes
----------
_nlp : TokenClassificationPipeline
固有表現抽出パイプライン
"""
def __init__(self, model_dir: str):
"""
コンストラクタ
_nlpを作成する
Parameters
----------
model_dir : str
使用する言語モデルのディレクトリ
"""
self._nlp = NaturalLanguageProcessing._create(model_dir)
@staticmethod
def _create(model_dir: str) -> TokenClassificationPipeline:
"""
パイプラインの作成
Parameters
----------
model_dir : str
使用する言語モデルのディレクトリ
Returns
-------
TokenClassificationPipeline
固有表現抽出パイプライン
"""
tokenizer = BertJapaneseTokenizer.from_pretrained(model_dir)
model = BertForTokenClassification.from_pretrained(model_dir)
nlp = pipeline(
'token-classification',
model=model,
tokenizer=tokenizer,
aggregation_strategy='simple'
)
return nlp
def classify(self, input: str) -> Dict[str, List[str]]:
"""
固有表現の抽出
Parameters
----------
input : str
固有表現抽出対象
Returns
-------
Dict[str, List[str]]
抽出結果の辞書
キーが分類ラベル、バリューがそのラベルの文字列のリスト
"""
prediction_results:List[Dict[str, str | float | None]] = self._nlp(input)
classified_words = {}
for predict_result in prediction_results:
label = predict_result['entity_group']
word = predict_result['word']
if label not in classified_words:
classified_words[label] = []
classified_words[label].append(word.replace(' ', ''))
return classified_words
def classify_and_show(self, input: str) -> Dict[str, List[str]]:
"""
固有表現の抽出と表示
Parameters
----------
input : str
固有表現抽出対象
Returns
-------
Dict[str, List[str]]
抽出結果の辞書
キーが分類ラベル、バリューがそのラベルの文字列のリスト
"""
classified_words = self.classify(input)
for label, words in classified_words.items():
print(f'{label: <10} {"、".join(words)}')
return classified_words