# -*- coding: utf-8 -*-
import json
import pandas as pd
import numpy as np
import torch
import streamlit as st
from transformers import AutoTokenizer, AutoModelForTokenClassification
st.set_page_config(
    page_title="NER-based Sensitive Information Detection", layout="wide", initial_sidebar_state="expanded"
)
def load_model(model_name):
    # Load the fine-tuned token-classification model from the Hugging Face Hub
    model = AutoModelForTokenClassification.from_pretrained(model_name)
    return model
st.title("NER-based Sensitive Information Detector")
st.write("Enter a sentence and press CTRL+Enter (CMD+Enter)")
tokenizer = AutoTokenizer.from_pretrained("klue/roberta-base")
model = load_model("QuoQA-NLP/konec-privacy")
model.eval()
default_value = "The diabetes test results have come out."
src_text = st.text_area(
    "Enter the sentence you want to check.",
    default_value,
    height=300,
    max_chars=150,
)
def yield_df(input_text):
    tokenized = tokenizer.encode(input_text)
    print(tokenized)
    # run the model in inference mode, without tracking gradients
    with torch.no_grad():
        output = model(input_ids=torch.tensor([tokenized]))
    logits = output.logits
    print(logits.size())
    # get the predicted class for each token (17 classes)
    pred = logits.argmax(-1).squeeze().numpy()
    print(pred)
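    # The label set below follows the BIO tagging scheme: a "B-" tag marks the first
    # token of an entity span, "I-" marks a continuation token, and "O" marks tokens
    # with no sensitive entity. 8 entity types x (B, I) + O = 17 classes.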
    class_map = {
        "B-ADD": 0,
        "I-ADD": 1,
        "B-DN": 2,
        "I-DN": 3,
        "B-DT": 4,
        "I-DT": 5,
        "B-LC": 6,
        "I-LC": 7,
        "B-OG": 8,
        "I-OG": 9,
        "B-PS": 10,
        "I-PS": 11,
        "B-QT": 12,
        "I-QT": 13,
        "B-RL": 14,
        "I-RL": 15,
        "O": 16,
    }
    class_map_inverted = {v: k for k, v in class_map.items()}
    # decode prediction indices back into label strings
    class_decoded = [class_map_inverted[p] for p in pred]
    print(class_decoded)
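    # Human-readable display names for each entity type, used in the results table
    # and as the keys of the masking map further below.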
    label_map = {
        "ADD": "Address info",
        "DN": "Disease info",
        "DT": "Date info",
        "LC": "Location info",
        "OG": "Organization info",
        "PS": "Person/nickname info",
        "QT": "Quantity info",
        "RL": "Relationship info",
        "O": "Non-sensitive info",
    }
    # pair each token with its predicted label
    tokenized_text = tokenizer.convert_ids_to_tokens(tokenized)
    list_result = []
    for token, pred_label in zip(tokenized_text, class_decoded):
        splitted_pred = pred_label.split("-")
        pred_class = splitted_pred[-1]
        label = label_map[pred_class]
        result = {"Morpheme": token, "Predicted label": label}
        list_result.append(result)
    df = pd.DataFrame(list_result)
    # drop the first and last rows ([CLS] and [SEP] special tokens)
    df = df.iloc[1:-1]
    return df
def convert_df(df: pd.DataFrame):
    return df.to_csv(index=False).encode('utf-8')
def convert_json(df: pd.DataFrame):
    result = df.to_json(orient="index")
    parsed = json.loads(result)
    json_string = json.dumps(parsed)
    # st.json(json_string, expanded=True)
    return json_string
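# Mask strings that replace tokens predicted as sensitive when the filtered
# sentence is rendered; keys must match the display names in label_map.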
filtering_map = {
    "Address info": "[ADDRESS]",
    "Disease info": "[DISEASE]",
    "Date info": "[DATE]",
    "Location info": "[LOCATION]",
    "Organization info": "[ORGANIZATION]",
    "Person/nickname info": "[NAME]",
    "Quantity info": "[QUANTITY]",
    "Relationship info": "[RELATION]",
    "Non-sensitive info": "[NON-SENSITIVE]",
}
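# Render the results: warn when the input is empty; otherwise show the masked
# sentence, the per-token classification table, and the download buttons.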
if src_text == "":
    st.warning("Please **enter text** to analyze")
else:
    df_result = yield_df(src_text)
    st.markdown("### Filtered sentence")
    display_result = ""
    for index, row in df_result.iterrows():
        token_info = row["Morpheme"]
        label_info = row["Predicted label"]
        if label_info != "Non-sensitive info":
            # replace sensitive tokens with their mask string
            token_info = filtering_map[label_info]
        if "##" in token_info:
            # "##" marks a subword continuation: glue it to the previous token
            token_info = token_info.replace("##", "")
        else:
            token_info = " " + token_info
        display_result += token_info
    st.write(display_result)
    st.markdown("### Classified words")
    st.header("")
    cs, c1, c2, c3, cLast = st.columns([0.75, 1.5, 1.5, 1.5, 0.75])
    st.table(df_result)
    with c1:
        # csvbutton = download_button(results, "results.csv", "📥 Download .csv")
        csvbutton = st.download_button(label="📥 Download as csv", data=convert_df(df_result), file_name="results.csv", mime='text/csv', key='csv')
    with c2:
        # textbutton = download_button(results, "results.txt", "📥 Download .txt")
        textbutton = st.download_button(label="📥 Download as txt", data=convert_df(df_result), file_name="results.txt", mime='text/plain', key='text')
    with c3:
        # jsonbutton = download_button(results, "results.json", "📥 Download .json")
        jsonbutton = st.download_button(label="📥 Download as json", data=convert_json(df_result), file_name="results.json", mime='application/json', key='json')
with st.expander("About the QuoQA AI Inc. demo evaluation", expanded=True):
    st.write(
        """
        This demo is the result of research funded in 2022 by the Ministry of Science and ICT and carried out with the support of the National IT Industry Promotion Agency
        (Project No.: A1504-22-1005)
        """
    )