iSpr's picture
Update app.py
84b0359
raw
history blame
4.03 kB
from asyncore import write
from pickletools import stringnl
import streamlit as st
import pandas as pd
# ๋ชจ๋ธ ์ค€๋น„ํ•˜๊ธฐ
from transformers import RobertaForSequenceClassification, AutoTokenizer
import numpy as np
import pandas as pd
import torch
import os
# Streamlit theme settings are configuration, not Python: they belong in
# .streamlit/config.toml next to this app, e.g.
#
#   [theme]
#   base="dark"
#   primaryColor="purple"
#
# (Written as Python, the `[theme]` line evaluates the undefined name
# `theme` and stops the script with NameError before anything renders.)

# Page title
st.header('ํ•œ๊ตญํ‘œ์ค€์‚ฐ์—…๋ถ„๋ฅ˜ ์ž๋™์ฝ”๋”ฉ ์„œ๋น„์Šค')
# Load the artefacts once per server process. Streamlit re-executes this
# script on every widget interaction; experimental_singleton hands back the
# same objects each time, whereas experimental_memo would unpickle a fresh
# copy of the whole model on every rerun.
@st.experimental_singleton
def md_loading():
    """Load the tokenizer, fine-tuned classifier and lookup tables.

    Returns:
        tokenizer: tokenizer for 'klue/roberta-base'.
        model: RobertaForSequenceClassification with 495 output labels,
            weights restored from ./upsampling_20.bin onto the CPU, in
            eval mode.
        label_tbl: array loaded from ./label_table.npy; the app indexes it
            with predicted class ids to obtain classification codes.
        loc_tbl: DataFrame read from ./kisc_table.csv (UTF-8); the app looks
            up the display name of each predicted code in it.
    """
    tokenizer = AutoTokenizer.from_pretrained('klue/roberta-base')
    model = RobertaForSequenceClassification.from_pretrained('klue/roberta-base', num_labels=495)
    model_checkpoint = 'upsampling_20.bin'
    project_path = './'
    output_model_file = os.path.join(project_path, model_checkpoint)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    model.load_state_dict(torch.load(output_model_file, map_location=torch.device('cpu')))
    # Inference only: make sure dropout is disabled.
    model.eval()
    label_tbl = np.load('./label_table.npy')
    loc_tbl = pd.read_csv('./kisc_table.csv', encoding='utf-8')
    print('ready')
    return tokenizer, model, label_tbl, loc_tbl
# Fetch the model artefacts (served from the cache after the first run).
tokenizer, model, label_tbl, loc_tbl = md_loading()

# Text input boxes, created in display order. Commas are stripped from every
# answer because the answers are joined with ', ' below and the prediction
# step splits that string on ',' again.
business, business_work, work_department, work_position, what_do_i = [
    st.text_input(label).replace(',', '')
    for label in (
        '์‚ฌ์—…์ฒด๋ช…',
        '์‚ฌ์—…์ฒด ํ•˜๋Š”์ผ',
        '๊ทผ๋ฌด๋ถ€์„œ',
        '์ง์ฑ…',
        '๋‚ด๊ฐ€ ํ•˜๋Š” ์ผ',
    )
]

# md_input: the single comma-separated string fed to the model.
md_input = ', '.join(
    (business, business_work, work_department, work_position, what_do_i)
)
# Number of word-piece ids fed to the model; longer inputs are truncated.
MAX_SEQ_LEN = 64
# Number of candidate classification codes shown to the user.
TOP_K = 5


def _encode_query(text, tok, max_len=MAX_SEQ_LEN):
    """Turn the comma-joined answers into padded model inputs.

    The sequence is '[CLS] ' followed by each of the first five
    comma-separated fields, each with a trailing space. It is word-piece
    tokenized, truncated to ``max_len`` ids and zero-padded on the right.

    Args:
        text: answers joined with ', '.
        tok: tokenizer providing ``tokenize`` and ``convert_tokens_to_ids``.
        max_len: fixed sequence length of the returned tensors.

    Returns:
        (input_ids, attention_mask): two ``torch.long`` tensors of shape
        ``(1, max_len)``; the mask is 1 over real tokens and 0 over padding.
    """
    fields = text.split(',')[:5]
    seq = '[CLS] ' + ''.join(field + ' ' for field in fields)
    ids = tok.convert_tokens_to_ids(tok.tokenize(seq))[:max_len]

    input_ids = torch.zeros((1, max_len), dtype=torch.long)
    attention_mask = torch.zeros((1, max_len), dtype=torch.long)
    input_ids[0, :len(ids)] = torch.tensor(ids, dtype=torch.long)
    attention_mask[0, :len(ids)] = 1
    return input_ids, attention_mask


def _top_k_table(logits, labels, code_table, k=TOP_K):
    """Build the ranked table of the ``k`` most likely classification codes.

    Args:
        logits: classifier output for a single example (flattened here).
        labels: array mapping class index -> classification code.
        code_table: DataFrame holding a code column and a name column, used
            to look up the display name of each predicted code.
        k: number of candidates to return.

    Returns:
        DataFrame indexed by rank 'NO' (1..k) with the code and its name.
        A code that has no row in ``code_table`` gets an empty name instead
        of crashing the page with IndexError.
    """
    top_idx = torch.topk(logits.flatten(), k).indices.cpu().numpy()
    codes = labels[top_idx]
    names = []
    for code in codes:
        matches = code_table['ํ•ญ๋ชฉ๋ช…'][code_table['์ฝ”๋“œ'] == code]
        names.append(matches.iloc[0] if len(matches) else '')
    return pd.DataFrame({
        'NO': range(1, k + 1),
        '์„ธ๋ถ„๋ฅ˜ ์ฝ”๋“œ': codes,
        '์„ธ๋ถ„๋ฅ˜ ๋ช…์นญ': names,
    }).set_index('NO')


# Button: run the classifier only when the user confirms the answers.
if st.button('ํ™•์ธ'):
    input_ids, attention_mask = _encode_query(md_input, tokenizer)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=None)
    # Show the top-k candidate codes and their names.
    st.dataframe(_top_k_table(outputs.logits, label_tbl, loc_tbl))