|
|
|
import json |
|
import pandas as pd |
|
import numpy as np |
|
import torch |
|
import streamlit as st |
|
from transformers import AutoTokenizer, AutoModelForTokenClassification |
|
|
|
# Configure the Streamlit page: browser-tab title, wide layout, and a sidebar
# that starts expanded.
# NOTE(review): the non-ASCII part of the title looks mis-encoded (mojibake) in
# this copy of the file -- confirm the file is stored and read as UTF-8.
st.set_page_config( |
|
page_title="NER ๊ธฐ๋ฐ ๋ฏผ๊ฐ์ ๋ณด ์๋ณ", layout="wide", initial_sidebar_state="expanded" |
|
) |
|
|
|
@st.cache(allow_output_mutation=True)
def load_model(model_name):
    """Load a token-classification model and cache it across script reruns.

    Args:
        model_name: Identifier passed straight to
            ``AutoModelForTokenClassification.from_pretrained``.

    Returns:
        The model instance returned by ``from_pretrained``.
    """
    # allow_output_mutation=True keeps st.cache from hashing the returned
    # model on every call to look for mutations; hashing a large torch module
    # is slow and can fail on unhashable members. The caller only reads the
    # model for inference, so sharing one cached instance is intended.
    return AutoModelForTokenClassification.from_pretrained(model_name)
|
|
|
|
|
# Page heading followed by a one-line usage hint (mentions CTRL+Enter / CMD+Enter).
# NOTE(review): the st.write literal is split across two physical lines in this
# copy of the file and its non-ASCII text looks mis-encoded -- confirm encoding.
st.title("๐ NER ๊ธฐ๋ฐ ๋ฏผ๊ฐ์ ๋ณด ์๋ณ๊ธฐ") |
|
st.write("๋ฌธ์ฅ์ ์
๋ ฅํ์๊ณ , CTRL+Enter(CMD+Enter)๋ฅผ ๋๋ฅด์ธ์ ๐ค") |
|
|
|
@st.cache(allow_output_mutation=True)
def _load_tokenizer(tokenizer_name):
    """Load the tokenizer once and reuse it across script reruns."""
    # Streamlit re-executes the whole script on each interaction; without the
    # cache the tokenizer would be rebuilt every time while the model is not.
    return AutoTokenizer.from_pretrained(tokenizer_name)


# Module-level handles used by yield_df().
tokenizer = _load_tokenizer("klue/roberta-base")
model = load_model("QuoQA-NLP/konec-privacy")

# Put the model in evaluation mode before running inference.
model.eval()
|
|
|
|
|
# Example sentence pre-filled in the input box.
# NOTE(review): this literal and the text_area label below are each split
# across two physical lines in this copy of the file, and their non-ASCII text
# looks mis-encoded -- confirm the file encoding.
default_value = "์์ง๋, ๋น๋จ ๊ฒ์ฌํ ๊ฑฐ ๊ฒฐ๊ณผ ๋์ค์
จ์ด์." |
|
|
|
# Multi-line input widget: label, initial value, 300 px height, and a hard
# limit of 150 characters. The current content is kept in src_text.
src_text = st.text_area( |
|
"๊ฒ์ฌํ๊ณ ์ถ์ ๋ฌธ์ฅ์ ์
๋ ฅํ์ธ์.", |
|
default_value, |
|
height=300, |
|
max_chars=150, |
|
) |
|
|
|
|
|
# Tag every token of a sentence and return the result as a DataFrame.
#
# The text is encoded with the module-level `tokenizer`, passed through the
# module-level `model` as a batch of one, and each token's highest-scoring
# class id is mapped to a display label. The returned frame has one row per
# token with two columns (token text, predicted label); the first and last
# rows are dropped before returning.
#
# NOTE(review): the parameter is named `default_value`, but the caller passes
# the text-area content (`src_text`), not only the default sentence.
# NOTE(review): indentation is not visible in this copy of the file; the
# statements down to `return df` presumably form the function body.
def yield_df(default_value): |
|
# Token ids for the input text.
tokenized = tokenizer.encode(default_value) |
|
# Debug output (stdout).
print(tokenized) |
|
|
|
# Wrap the ids in a list so the model receives a batch of size one.
# NOTE(review): the forward pass is not wrapped in torch.no_grad() -- consider
# adding it, since only the argmax of the logits is used.
output = model(input_ids=torch.tensor([tokenized])) |
|
logits = output.logits |
|
# Debug output (stdout).
print(logits.size()) |
|
|
|
|
|
|
|
# Highest-scoring class id per position; squeeze() removes the size-one
# dimensions before the conversion to a NumPy array.
pred = logits.argmax(-1).squeeze().numpy() |
|
# Debug output (stdout).
print(pred) |
|
|
|
# Tag name -> class id.
# NOTE(review): hard-coded here; presumably this has to match the label order
# the model was trained with -- verify against the model configuration.
class_map = { |
|
"B-ADD": 0, |
|
"I-ADD": 1, |
|
"B-DN": 2, |
|
"I-DN": 3, |
|
"B-DT": 4, |
|
"I-DT": 5, |
|
"B-LC": 6, |
|
"I-LC": 7, |
|
"B-OG": 8, |
|
"I-OG": 9, |
|
"B-PS": 10, |
|
"I-PS": 11, |
|
"B-QT": 12, |
|
"I-QT": 13, |
|
"B-RL": 14, |
|
"I-RL": 15, |
|
"O": 16 |
|
} |
|
|
|
# Class id -> tag name, used to decode the predictions.
class_map_inverted = {v: k for k, v in class_map.items()} |
|
|
|
|
|
|
|
class_decoded = [class_map_inverted[p] for p in pred] |
|
# Debug output (stdout).
print(class_decoded) |
|
|
|
# Tag suffix (without the "B-"/"I-" prefix) -> label text shown to the user.
# NOTE(review): the "PS" value is split across three physical lines in this
# copy of the file and the non-ASCII text looks mis-encoded -- confirm encoding.
label_map = { |
|
"ADD": "์ฃผ์ ์ ๋ณด", |
|
"DN": "์งํ ์ ๋ณด", |
|
"DT": "๋ ์ง ์ ๋ณด", |
|
"LC": "์ฅ์ ์ ๋ณด", |
|
"OG": "๊ธฐ๊ด ์ ๋ณด", |
|
"PS": "์ธ๋ช
/๋ณ๋ช
์ ๋ณด", |
|
"QT": "์๋ ์ ๋ณด", |
|
"RL": "๊ด๊ณ ์ ๋ณด", |
|
"O": "๋น๋ฏผ๊ฐ ์ ๋ณด" |
|
} |
|
|
|
|
|
|
|
|
|
|
|
# Token strings corresponding to the ids, in the same order.
tokenized_text = tokenizer.convert_ids_to_tokens(tokenized) |
|
list_result = [] |
|
# NOTE(review): the loop variable `pred` rebinds the name of the prediction
# array computed above; the array is not used after this point.
for token, pred in zip(tokenized_text, class_decoded): |
|
# The last piece after splitting on "-" is the tag suffix; a bare "O" is
# returned unchanged.
splitted_pred = pred.split("-") |
|
pred_class = splitted_pred[-1] |
|
label = label_map[pred_class] |
|
|
|
# One row per token; these two keys become the DataFrame column names.
result = {"ํํ์":token, "์์ ๋ผ๋ฒจ":label} |
|
list_result.append(result) |
|
|
|
df = pd.DataFrame(list_result) |
|
|
|
# Drop the first and last rows -- presumably the special tokens the tokenizer
# adds at both ends of the sequence (TODO confirm).
df = df.iloc[1:-1] |
|
return df |
|
|
|
def convert_df(df: pd.DataFrame):
    """Return *df* rendered as CSV text (index column omitted), as UTF-8 bytes."""
    csv_text = df.to_csv(index=False)
    return bytes(csv_text, "utf-8")
|
|
|
def convert_json(df: pd.DataFrame):
    """Return *df* as a JSON string keyed by row index.

    The frame is serialized with ``orient="index"``, parsed back into Python
    objects, and dumped again with the standard ``json`` module, so the
    returned text uses ``json.dumps`` default formatting.
    """
    rows_by_index = json.loads(df.to_json(orient="index"))
    return json.dumps(rows_by_index)
|
|
|
|
|
|
|
# Display label -> bracketed placeholder written into the filtered sentence.
# The keys are looked up with the label strings produced by yield_df(), so
# they have to match the values of its label_map exactly.
# NOTE(review): one key is split across three physical lines in this copy of
# the file and the non-ASCII text looks mis-encoded -- confirm the encoding.
filtering_map = { |
|
"์ฃผ์ ์ ๋ณด": "[์ฃผ์]", |
|
"์งํ ์ ๋ณด": "[์งํ]", |
|
"๋ ์ง ์ ๋ณด": "[๋ ์ง]", |
|
"์ฅ์ ์ ๋ณด": "[์ฅ์]", |
|
"๊ธฐ๊ด ์ ๋ณด": "[๊ธฐ๊ด]", |
|
"์ธ๋ช
/๋ณ๋ช
์ ๋ณด": "[์ด๋ฆ]", |
|
"์๋ ์ ๋ณด": "[์๋]", |
|
"๊ด๊ณ ์ ๋ณด": "[๊ด๊ณ]", |
|
"๋น๋ฏผ๊ฐ ์ ๋ณด": "[๋น๋ฏผ๊ฐ]" |
|
} |
|
|
|
# Main flow: warn when the input box is empty; otherwise tag the text, show
# the filtered sentence and the per-token table, and offer three downloads.
# NOTE(review): indentation is not visible in this copy of the file; every
# statement after `else:` presumably belongs to the else branch.
# NOTE(review): the warning text mentions "translation" although this script
# labels tokens -- looks like a leftover message; confirm the wording.
if src_text == "": |
|
st.warning("Please **enter text** for translation") |
|
else: |
|
df_result = yield_df(src_text) |
|
st.markdown("### ํํฐ๋ง ๋ ๋ฌธ์ฅ") |
|
|
|
# Rebuild the sentence token by token.
display_result = "" |
|
for index, row in df_result.iterrows(): |
|
token_info = row["ํํ์"] |
|
label_info = row["์์ ๋ผ๋ฒจ"] |
|
# Tokens with any label other than the non-sensitive one are replaced by
# the bracketed placeholder for that label.
if label_info != "๋น๋ฏผ๊ฐ ์ ๋ณด": |
|
token_info = filtering_map[label_info] |
|
|
|
# A token containing "##" is appended without a leading space after the
# marker is removed; any other token is preceded by a space.
# NOTE(review): this check runs after the substitution above, and no
# placeholder contains "##", so replaced tokens always get a leading space.
if "##" in token_info: |
|
token_info = token_info.replace("##", "") |
|
else: |
|
token_info = " " + token_info |
|
display_result += token_info |
|
|
|
st.write(display_result) |
|
|
|
st.markdown("### ๋ถ๋ฅ๋ ๋จ์ด๋ค") |
|
st.header("") |
|
# Five columns with relative widths; only c1, c2 and c3 are used below.
cs, c1, c2, c3, cLast = st.columns([0.75, 1.5, 1.5, 1.5, 0.75]) |
|
|
|
# Per-token table (token text and predicted label).
st.table(df_result) |
|
with c1: |
|
|
|
csvbutton = st.download_button(label="๐ฅ csv๋ก ๋ค์ด๋ก๋", data=convert_df(df_result), file_name= "results.csv", mime='text/csv', key='csv') |
|
with c2: |
|
|
|
# Same CSV bytes as above, served as plain text.
# NOTE(review): the file name ends in ".text" while the label says txt --
# confirm whether ".txt" was intended.
textbutton = st.download_button(label="๐ฅ txt๋ก ๋ค์ด๋ก๋", data=convert_df(df_result), file_name= "results.text", mime='text/plain', key='text') |
|
with c3: |
|
|
|
jsonbutton = st.download_button(label="๐ฅ json์ผ๋ก ๋ค์ด๋ก๋", data=convert_json(df_result), file_name= "results.json", mime='application/json', key='json') |
|
|
|
|
|
|
|
# Expander section, open by default, that shows a static multi-line text
# (it includes the identifier A1504-22-1005).
# NOTE(review): the text is split across extra physical lines in this copy of
# the file and its non-ASCII content looks mis-encoded -- confirm the encoding.
with st.expander("(์ฃผ) ์ฟผ์นด์์ด์์ด ๋ฐ๋ชจ ์ฌ์ฌ ๊ด๋ จ", expanded=True): |
|
|
|
st.write( |
|
""" |
|
ํด๋น ๋ฐ๋ชจ๋ 2022๋
๋ ๊ณผํ๊ธฐ์ ์ ๋ณดํต์ ๋ถ์ ์ฌ์์ผ๋ก ์ ๋ณดํต์ ์ฐ์
์งํฅ์์ ์ง์์ ๋ฐ์ ์ํ๋ ์ฐ๊ตฌ์ |
|
(๊ณผ์ ๋ฒํธ: A1504-22-1005) |
|
""" |
|
) |