Upload 2 files
- bert_classification(0.87).pt +3 -0
- utils.py +156 -0
bert_classification(0.87).pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5847ffe4878ab7c57a22308310065d10fb603d2817eaffe12af543832cbd3610
+size 711548626
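The checkpoint itself is stored through Git LFS, so the committed blob is only this pointer; the actual weights are resolved from the oid and size recorded above. Below is a minimal sketch for verifying a downloaded copy against the pointer. The file path and chunk size are illustrative assumptions, not part of the upload.

import hashlib

# Values copied from the LFS pointer above.
EXPECTED_OID = "5847ffe4878ab7c57a22308310065d10fb603d2817eaffe12af543832cbd3610"
EXPECTED_SIZE = 711548626

def verify_checkpoint(path="bert_classification(0.87).pt"):
    """Return True if the local file matches the pointer's sha256 and size."""
    sha256 = hashlib.sha256()
    size = 0
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            sha256.update(chunk)
            size += len(chunk)
    return sha256.hexdigest() == EXPECTED_OID and size == EXPECTED_SIZE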
utils.py
ADDED
@@ -0,0 +1,156 @@
+# -*- coding: utf-8 -*-
+"""utils(2).ipynb
+
+Automatically generated by Colaboratory.
+
+Original file is located at
+    https://colab.research.google.com/drive/1snWVRieogxGIRp-UsTCZWjLM5ir5KQxB
+"""
+
+import re
+import nltk
+import torch
+import numpy as np
+
+from nltk.tokenize import TweetTokenizer
+from nltk.stem import WordNetLemmatizer
+from nltk.corpus import stopwords
+from nltk.corpus import wordnet
+from transformers import BertTokenizer
+from keras.preprocessing.sequence import pad_sequences
+
+nltk.download('stopwords')
+
+stopword_list = nltk.corpus.stopwords.words('english')
+stopword_list.remove('no')
+stopword_list.remove('not')
+
+nltk.download('punkt')
+nltk.download('averaged_perceptron_tagger')
+nltk.download('wordnet')
+
+tokenizer = TweetTokenizer()
+lemmatizer = WordNetLemmatizer()
+tokenizer_B = BertTokenizer.from_pretrained('bert-base-multilingual-cased', do_lower_case=True)
+
+device = torch.device("cuda")
+
+# Get a word's part of speech (POS) with the wordnet module
+def get_wordnet_pos(word):
+    """Map POS tag to first character lemmatize() accepts"""
+    tag = nltk.pos_tag([word])[0][1][0].upper()
+    tag_dict = {"J": wordnet.ADJ,   # adjective
+                "N": wordnet.NOUN,  # noun
+                "V": wordnet.VERB,  # verb
+                "R": wordnet.ADV}   # adverb
+
+    return tag_dict.get(tag, wordnet.NOUN)
+
+# Preprocessing function
+def pre_data(data):
+
+    # Lowercase
+    df2 = data.lower().strip()
+
+    # Tokenize with TweetTokenizer
+    df_token = tokenizer.tokenize(df2)
+
+    # Remove @mentions
+    df_IDdel = []
+    for word in df_token:
+        if '@' not in word:
+            df_IDdel.append(word)
+
+    # Join the tokens back into a sentence
+    df_IDdel_sen = ' '.join(df_IDdel)
+
+    # Replace non-English characters with spaces
+    df_eng = re.sub("[^a-zA-Z]", " ", df_IDdel_sen)
+
+    # Collapse repeated letters (at most two in a row are kept)
+    df_rep_list = []
+    for i, e in enumerate(df_eng):
+        if i > 1 and e == df_eng[i - 2] and e == df_eng[i - 1]:
+            df_rep_list.append('')
+        else:
+            df_rep_list.append(e)
+    df_rep = ''.join(df_rep_list)
+    # Collapse consecutive whitespace
+    df_rep = re.sub(r'\s+', ' ', df_rep)
+
+    # Lemmatize
+    df_lemma = [lemmatizer.lemmatize(w, get_wordnet_pos(w)) for w in nltk.word_tokenize(df_rep)]
+
+    # Remove stopwords
+    df_clean = [w for w in df_lemma if not w in stopword_list]
+
+    if len(df_clean) == 0:
+        df_clean = 'NC'  # NC = No Category: placeholder used when the list is empty, not a meaningful word
+    else:
+        df_clean = ' '.join(df_clean)
+
+    return df_clean
+
+# Convert input data
+def convert_input_data(sentences):
+
+    # Split the sentences into tokens with the BERT tokenizer
+    tokenized_texts = [tokenizer_B.tokenize(sent) for sent in sentences]
+
+    # Maximum sequence length of the input tokens
+    MAX_LEN = 80
+
+    # Convert tokens to numeric indices
+    input_ids = [tokenizer_B.convert_tokens_to_ids(x) for x in tokenized_texts]
+
+    # Truncate sentences to MAX_LEN and pad the remainder with 0
+    input_ids = pad_sequences(input_ids, maxlen=MAX_LEN, dtype="long", truncating="post", padding="post")
+
+    # Initialize the attention masks
+    attention_masks = []
+
+    # Set the attention mask to 1 for real tokens and 0 for padding;
+    # BERT does not attend to padding positions, which speeds things up
+    for seq in input_ids:
+        seq_mask = [float(i > 0) for i in seq]
+        attention_masks.append(seq_mask)
+
+    # Convert the data to PyTorch tensors
+    inputs = torch.tensor(input_ids)
+    masks = torch.tensor(attention_masks)
+
+    return inputs, masks
+
+# Run inference on sentences
+def test_sentences(sentences, load_model):
+
+    # Convert the sentences to input data
+    inputs, masks = convert_input_data(sentences)
+
+    # Move the data to the GPU
+    b_input_ids = inputs.to(device)
+    b_input_mask = masks.to(device)
+
+    # No gradient computation
+    with torch.no_grad():
+        # Forward pass
+        outputs = load_model(b_input_ids,
+                             token_type_ids=None,
+                             attention_mask=b_input_mask)
+
+    # Get the logits
+    logits = outputs[0]
+
+    # Move the data to the CPU
+    logits = logits.detach().cpu().numpy()
+
+    return logits
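A minimal end-to-end sketch of how these helpers might be used with the uploaded checkpoint. It assumes the .pt file holds a complete BertForSequenceClassification model saved with torch.save (not just a state_dict) and that a CUDA device is available, since utils.py hard-codes device = torch.device("cuda"); the file name, example tweet, and label handling are illustrative only.

import numpy as np
import torch

from utils import pre_data, test_sentences

# Assumption: the checkpoint stores the whole model object, not just a state_dict.
model = torch.load("bert_classification(0.87).pt", map_location="cuda")
model.eval()

raw_tweet = "@someuser the service was not good at all!!!"
cleaned = pre_data(raw_tweet)              # lowercase, drop mentions, lemmatize, remove stopwords

logits = test_sentences([cleaned], model)  # numpy array of shape (1, num_labels)
predicted_label = int(np.argmax(logits, axis=1)[0])
print(cleaned, predicted_label)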