File size: 1,377 Bytes
6e757a0 bdbf479 6e757a0 5b40d74 6e757a0 de13b22 6e757a0 de13b22 6e757a0 de13b22 6e757a0 bdbf479 6e757a0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
import gradio as gr
import torch
from transformers import BertJapaneseTokenizer, BertForSequenceClassification
# Name of the pretrained Japanese BERT checkpoint; used below to load the tokenizer.
MODEL_NAME = "cl-tohoku/bert-base-japanese-whole-word-masking"

# Description text handed to gr.Interface when the app is built.
descriptions = "BERTを用いたビジネス文書のネガポジ判定。文章を入力すると、その文章のネガポジ判定と判定の信頼度を表示します。"

# The tokenizer is loaded by checkpoint name, while the classifier is loaded
# from the local ./models directory.
tokenizer = BertJapaneseTokenizer.from_pretrained(MODEL_NAME)
bert_sc_ = BertForSequenceClassification.from_pretrained("./models")

# Inference runs on the CPU.
bert_sc = bert_sc_.to("cpu")
def func(text):
    """Classify *text* as negative or positive and report the confidence.

    The text is tokenized with the module-level ``tokenizer`` and scored by
    the module-level ``bert_sc`` classifier. Class index 0 is reported as
    negative; any other predicted index is reported as positive.

    Args:
        text: The sentence(s) to classify, as a single string.

    Returns:
        A ``(label, confidence)`` tuple of display strings: the label is
        "ネガティブ" or "ポジティブ", and the confidence is the softmax
        probability of the reported class formatted as a percentage with
        one decimal place.
    """
    encoding = tokenizer(
        text,
        padding="longest",
        # Without truncation, input longer than the model's position
        # embeddings makes the forward pass fail, so cap it at that limit.
        truncation=True,
        max_length=bert_sc.config.max_position_embeddings,
        return_tensors="pt",
    )
    # Keep every input tensor on the CPU, where the model lives.
    encoding = {k: v.cpu() for k, v in encoding.items()}

    with torch.no_grad():
        output = bert_sc(**encoding)

    # A single string yields a batch of one, so .item() is well defined.
    predicted = output.logits.argmax(-1).item()
    # Compute the class probabilities once and reuse them.
    probs = torch.softmax(output.logits, dim=1).tolist()[0]

    if predicted == 0:
        label = "ネガティブ"
        confidence = probs[0]
    else:
        label = "ポジティブ"
        confidence = probs[1]

    cos = f"信頼度:{confidence*100:.1f}%"
    return label, cos
# Build the Gradio UI: a three-line textbox feeds func, and its two returned
# strings (label and confidence) are each shown in a "label" output.
app = gr.Interface(
    fn=func,
    inputs=gr.Textbox(lines=3, placeholder="文章を入力してください"),
    outputs=["label", "label"],
    title="Sentiment Analysis of Business Documents",
    description=descriptions,
)
app.launch()