import os

import numpy as np
import pandas as pd
import streamlit as st
import torch
# Model setup
from transformers import AutoTokenizer, RobertaForSequenceClassification

# The theme settings below are Streamlit configuration (TOML), not Python.
# They belong in .streamlit/config.toml:
#
# [theme]
# base="dark"
# primaryColor="purple"

# Page title ("Korean Standard Industrial Classification auto-coding service")
st.header('한국표준산업분류 자동코딩 서비스')


# Cache the tokenizer, model and lookup tables so they are not reloaded on every
# rerun. st.cache_resource (Streamlit >= 1.18) is intended for unserializable
# objects such as models; st.experimental_memo is deprecated.
@st.cache_resource(max_entries=20)
def md_loading():
    # Weights are mapped to the CPU when loaded
    tokenizer = AutoTokenizer.from_pretrained('klue/roberta-base')
    model = RobertaForSequenceClassification.from_pretrained(
        'klue/roberta-base', num_labels=495)

    model_checkpoint = 'upsampling_20.bin'
    project_path = './'
    output_model_file = os.path.join(project_path, model_checkpoint)
    model.load_state_dict(
        torch.load(output_model_file, map_location=torch.device('cpu')))
    model.eval()  # inference mode: disables dropout

    label_tbl = np.load('./label_table.npy')
    loc_tbl = pd.read_csv('./kisc_table.csv', encoding='utf-8')

    print('ready')

    return tokenizer, model, label_tbl, loc_tbl


# Load the model
tokenizer, model, label_tbl, loc_tbl = md_loading()

# Text input boxes. Commas are stripped because a comma is used as the field
# separator in md_input below.
business = st.text_input('사업체명').replace(',', '')           # business name
business_work = st.text_input('사업체 하는일').replace(',', '')  # what the business does
work_department = st.text_input('근무부서').replace(',', '')     # department
work_position = st.text_input('직책').replace(',', '')           # job title
what_do_i = st.text_input('내가 하는 일').replace(',', '')       # what I do

# md_input: the value fed to the model (the five fields, comma-separated)
md_input = ', '.join([business, business_work, work_department, work_position, what_do_i])

# Button ('확인' = "Submit")
if st.button('확인'):
    # Actions performed when the button is clicked: run the model
    query_tokens = md_input.split(',')
    input_ids = np.zeros(shape=[1, 64])
    attention_mask = np.zeros(shape=[1, 64])

    # Build "[CLS] field1 field2 ..." from at most the five input fields
    seq = '[CLS] '
    for token in query_tokens[:5]:
        seq += token + ' '

    tokens = tokenizer.tokenize(seq)
    ids = tokenizer.convert_tokens_to_ids(tokens)

    # Truncate to 64 tokens, copy the ids and mark real tokens in the mask
    length = min(len(ids), 64)
    for i in range(length):
        input_ids[0, i] = ids[i]
        attention_mask[0, i] = 1

    input_ids = torch.from_numpy(input_ids).type(torch.long)
    attention_mask = torch.from_numpy(attention_mask).type(torch.long)

    with torch.no_grad():  # gradients are not needed for inference
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=None)
    logits = outputs.logits

    # Top-1 prediction only:
    # arg_idx = torch.argmax(logits, dim=1)
    # num_ans = label_tbl[arg_idx]
    # str_ans = loc_tbl['항목명'][loc_tbl['코드'] == num_ans].values

    # Top-k predictions
    k = 5
    topk_idx = torch.topk(logits.flatten(), k).indices.numpy()
    num_ans_topk = label_tbl[topk_idx]
    # kisc_table.csv columns: '코드' = code, '항목명' = item name
    str_ans_topk = [loc_tbl['항목명'][loc_tbl['코드'] == code] for code in num_ans_topk]

    # Show the model output
    str_ans_topk_list = []
    for i in range(k):
        str_ans_topk_list.append(str_ans_topk[i].iloc[0])

    # Result table: '세분류 코드' = sub-class code, '세분류 명칭' = sub-class name
    ans_topk_df = pd.DataFrame({
        'NO': range(1, k + 1),
        '세분류 코드': num_ans_topk,
        '세분류 명칭': str_ans_topk_list
    })
    ans_topk_df = ans_topk_df.set_index('NO')
    st.dataframe(ans_topk_df)
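

# Sketch (hypothetical helper, not called by the app above): the fixed-length
# encoding step written as a pure function so that it can be tested without
# loading the model. This assumes `ids` is the list returned by
# tokenizer.convert_tokens_to_ids, and that right-padding with id 0 matches how
# the checkpoint was trained, as the inline loop does. The name
# build_fixed_length_inputs is illustrative.
import numpy as np


def build_fixed_length_inputs(ids, max_len=64):
    """Return (input_ids, attention_mask) as int64 arrays of shape [1, max_len].

    Sequences longer than max_len are truncated. Shorter ones are right-padded
    with 0, and attention_mask marks the real tokens with 1.
    """
    length = min(len(ids), max_len)
    input_ids = np.zeros((1, max_len), dtype=np.int64)
    attention_mask = np.zeros((1, max_len), dtype=np.int64)
    input_ids[0, :length] = ids[:length]
    attention_mask[0, :length] = 1
    return input_ids, attention_mask

# Usage: input_ids, attention_mask = build_fixed_length_inputs(ids)
# Because the arrays are already int64, torch.from_numpy(input_ids) yields a
# torch.long tensor, so no float-to-long cast is needed.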
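

# Sketch (hypothetical helper, not called by the app above): the top-k lookup
# written as a pure function. It replaces torch.topk with np.argsort so that it
# can be tested on plain arrays. It assumes label_tbl maps a class index to a
# KSIC code, and that loc_tbl has the '코드' (code) and '항목명' (item name)
# columns used above. A code missing from loc_tbl gets None instead of raising
# IndexError. The name topk_codes_table is illustrative.
import numpy as np
import pandas as pd


def topk_codes_table(logits, label_tbl, loc_tbl, k=5):
    """Build the ranked result table from the model's class logits."""
    scores = np.asarray(logits).ravel()
    # argsort sorts ascending, so reverse the order and keep the first k indices
    topk_idx = np.argsort(scores)[::-1][:k]
    codes = np.asarray(label_tbl)[topk_idx]
    names = []
    for code in codes:
        match = loc_tbl.loc[loc_tbl['코드'] == code, '항목명']
        names.append(match.iloc[0] if len(match) else None)
    return pd.DataFrame({
        'NO': range(1, len(codes) + 1),
        '세분류 코드': codes,
        '세분류 명칭': names,
    }).set_index('NO')

# Usage: st.dataframe(topk_codes_table(logits.numpy(), label_tbl, loc_tbl, k=5))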