iSpr commited on
Commit
55e98f1
โ€ข
1 Parent(s): f46a2b3

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +137 -0
app.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from asyncore import write
2
+ from pickletools import stringnl
3
+ import streamlit as st
4
+ import pandas as pd
5
+
6
+ # ๋ชจ๋ธ ์ค€๋น„ํ•˜๊ธฐ
7
+ from transformers import RobertaForSequenceClassification, AutoTokenizer
8
+ import numpy as np
9
+ import pandas as pd
10
+ import torch
11
+ import os
12
+
13
+ # ์ œ๋ชฉ ์ž…๋ ฅ
14
+ st.header('ํ•œ๊ตญํ‘œ์ค€์‚ฐ์—…๋ถ„๋ฅ˜ ์ž๋™์ฝ”๋”ฉ ์„œ๋น„์Šค')
15
+
16
+ # ์žฌ๋กœ๋“œ ์•ˆํ•˜๋„๋ก
17
+ @st.experimental_memo(max_entries=20)
18
+ def md_loading():
19
+ ## cpu
20
+ # device = torch.device('cpu')
21
+
22
+ tokenizer = AutoTokenizer.from_pretrained('klue/roberta-base')
23
+ model = RobertaForSequenceClassification.from_pretrained('klue/roberta-base', num_labels=495)
24
+
25
+ model_checkpoint = 'upsampling_20.bin'
26
+ project_path = './'
27
+ output_model_file = os.path.join(project_path, model_checkpoint)
28
+
29
+ model.load_state_dict(torch.load(output_model_file, map_location=torch.device('cpu')))
30
+
31
+ label_tbl = np.load('./label_table.npy')
32
+ loc_tbl = pd.read_csv('./kisc_table.csv', encoding='utf-8')
33
+
34
+ print('ready')
35
+
36
+ return tokenizer, model, label_tbl, loc_tbl
37
+
38
+ # ๋ชจ๋ธ ๋กœ๋“œ
39
+ tokenizer, model, label_tbl, loc_tbl = md_loading()
40
+
41
+
42
+ # ํ…์ŠคํŠธ input ๋ฐ•์Šค
43
+ business = st.text_input('์‚ฌ์—…์ฒด๋ช…', '์ถฉ์ฒญ์ง€๋ฐฉํ†ต๊ณ„์ฒญ').replace(',', '')
44
+ business_work = st.text_input('์‚ฌ์—…์ฒด ํ•˜๋Š”์ผ', 'ํ†ต๊ณ„์„œ๋น„์Šค ์ œ๊ณต ๋ฐ ์ง€์—ญํ†ต๊ณ„ ํ—ˆ๋ธŒ').replace(',', '')
45
+ work_department = st.text_input('๊ทผ๋ฌด๋ถ€์„œ', '์ง€์—ญํ†ต๊ณ„๊ณผ').replace(',', '')
46
+ work_position = st.text_input('์ง์ฑ…', '์ฃผ๋ฌด๊ด€').replace(',', '')
47
+ what_do_i = st.text_input('๋‚ด๊ฐ€ ํ•˜๋Š” ์ผ', 'ํ†ต๊ณ„๋ฐ์ดํ„ฐ์„ผํ„ฐ ์šด์˜').replace(',', '')
48
+
49
+ # md_input: ๋ชจ๋ธ์— ์ž…๋ ฅํ•  input ๊ฐ’ ์ •์˜
50
+ md_input = ', '.join([business, business_work, work_department, work_position, what_do_i])
51
+
52
+ ## ์ž„์‹œ ํ™•์ธ
53
+ # st.write(md_input)
54
+
55
+ # ๋ฒ„ํŠผ
56
+ if st.button('ํ™•์ธ'):
57
+ ## ๋ฒ„ํŠผ ํด๋ฆญ ์‹œ ์ˆ˜ํ–‰์‚ฌํ•ญ
58
+ ### ๋ชจ๋ธ ์‹คํ–‰
59
+ query_tokens = md_input.split(',')
60
+
61
+ input_ids = np.zeros(shape=[1, 64])
62
+ attention_mask = np.zeros(shape=[1, 64])
63
+
64
+ seq = '[CLS] '
65
+ try:
66
+ for i in range(5):
67
+ seq += query_tokens[i] + ' '
68
+ except:
69
+ None
70
+
71
+ tokens = tokenizer.tokenize(seq)
72
+ ids = tokenizer.convert_tokens_to_ids(tokens)
73
+
74
+ length = len(ids)
75
+ if length > 64:
76
+ length = 64
77
+
78
+ for i in range(length):
79
+ input_ids[0, i] = ids[i]
80
+ attention_mask[0, i] = 1
81
+
82
+ input_ids = torch.from_numpy(input_ids).type(torch.long)
83
+ attention_mask = torch.from_numpy(attention_mask).type(torch.long)
84
+
85
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=None)
86
+ logits = outputs.logits
87
+
88
+ # # ๋‹จ๋… ์˜ˆ์ธก ์‹œ
89
+ # arg_idx = torch.argmax(logits, dim=1)
90
+ # print('arg_idx:', arg_idx)
91
+
92
+ # num_ans = label_tbl[arg_idx]
93
+ # str_ans = loc_tbl['ํ•ญ๋ชฉ๋ช…'][loc_tbl['์ฝ”๋“œ'] == num_ans].values
94
+
95
+ # ์ƒ์œ„ k๋ฒˆ์งธ๊นŒ์ง€ ์˜ˆ์ธก ์‹œ
96
+ k = 5
97
+ topk_idx = torch.topk(logits.flatten(), k).indices
98
+
99
+ num_ans_topk = label_tbl[topk_idx]
100
+ str_ans_topk = [loc_tbl['ํ•ญ๋ชฉ๋ช…'][loc_tbl['์ฝ”๋“œ'] == k] for k in num_ans_topk]
101
+
102
+ # print(num_ans, str_ans)
103
+ # print(num_ans_topk)
104
+
105
+ # print('์‚ฌ์—…์ฒด๋ช…:', query_tokens[0])
106
+ # print('์‚ฌ์—…์ฒด ํ•˜๋Š”์ผ:', query_tokens[1])
107
+ # print('๊ทผ๋ฌด๋ถ€์„œ:', query_tokens[2])
108
+ # print('์ง์ฑ…:', query_tokens[3])
109
+ # print('๋‚ด๊ฐ€ ํ•˜๋Š”์ผ:', query_tokens[4])
110
+ # print('์‚ฐ์—…์ฝ”๋“œ ๋ฐ ๋ถ„๋ฅ˜:', num_ans, str_ans)
111
+
112
+ # ans = ''
113
+ # ans1, ans2, ans3 = '', '', ''
114
+
115
+ ## ๋ชจ๋ธ ๊ฒฐ๊ณผ๊ฐ’ ์ถœ๋ ฅ
116
+ # st.write("์‚ฐ์—…์ฝ”๋“œ ๋ฐ ๋ถ„๋ฅ˜:", num_ans, str_ans[0])
117
+ # st.write("์„ธ๋ถ„๋ฅ˜ ์ฝ”๋“œ")
118
+ # for i in range(k):
119
+ # st.write(str(i+1) + '์ˆœ์œ„:', num_ans_topk[i], str_ans_topk[i].iloc[0])
120
+
121
+ # print(num_ans)
122
+ # print(str_ans, type(str_ans))
123
+
124
+ str_ans_topk_list = []
125
+ for i in range(k):
126
+ str_ans_topk_list.append(str_ans_topk[i].iloc[0])
127
+
128
+ # print(str_ans_topk_list)
129
+
130
+ ans_topk_df = pd.DataFrame({
131
+ 'NO': range(1, k+1),
132
+ '์„ธ๋ถ„๋ฅ˜ ์ฝ”๋“œ': num_ans_topk,
133
+ '์„ธ๋ถ„๋ฅ˜ ๋ช…์นญ': str_ans_topk_list
134
+ })
135
+ ans_topk_df = ans_topk_df.set_index('NO')
136
+
137
+ st.dataframe(ans_topk_df)