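# Streamlit demo: an extractive question-answering "chatbot" over a fixed Japanese profile text.
# Run with `streamlit run app.py` (assuming the file is saved as app.py).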
from transformers import AutoModelForQuestionAnswering, BertJapaneseTokenizer
import streamlit as st
import torch
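
# Load a Japanese BERT model with a question-answering head and its matching tokenizer.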
model_name = 'KoichiYasuoka/bert-base-japanese-wikipedia-ud-head'
model = AutoModelForQuestionAnswering.from_pretrained(model_name)
tokenizer = BertJapaneseTokenizer.from_pretrained(model_name)
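
# Fixed context the model answers from (in Japanese): "My name is Yamada. My hobbies are
# watching videos and shopping. I am 30 years old. I am from Osaka Prefecture. I work as a doctor."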
context = "私の名前は山田です。趣味は動画鑑賞とショッピングです。年齢は30歳です。出身は大阪府です。仕事は医者です。"
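
# Minimal Streamlit UI: a text box for the question and a button that triggers the QA model.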
st.title("AI assistant")
user_input = st.text_input("You:", "")
if st.button("Send"):
    # Encode the question together with the context as a single QA input.
    inputs = tokenizer.encode_plus(user_input, context, add_special_tokens=True, return_tensors="pt")
    input_ids = inputs["input_ids"].tolist()[0]
    # Inference only, so disable gradient tracking.
    with torch.no_grad():
        output = model(**inputs)
    # The answer span runs from the highest-scoring start token to the highest-scoring end token.
    answer_start = torch.argmax(output.start_logits)
    answer_end = torch.argmax(output.end_logits) + 1
    answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end]))
    # Drop the spaces the tokenizer inserts between subword tokens.
    answer = answer.replace(' ', '')
    st.text_area("Chatbot:", answer, height=100)